Transform Messages Capability (#1923)
* wip
* Adds docstrings
* fixed spellings
* wip
* fixed errors
* better class names
* adds tests
* added tests to workflow
* improved token counting
* improved notebook
* improved token counting in test
* improved docstrings
* fix inconsistencies
* changed by mistake
* fixed docstring
* fixed details
* improves tests + adds openai contrib test
* fix spelling oai contrib test
* clearer docstrings
* remove repeated docstr
* improved notebook
* adds metadata to notebook
* Improve outline and description (#2125)
* better dir structure
* clip max tokens to allowed tokens
* more accurate comments/docstrs
* add deprecation warning
* fix front matter desc
* add deprecation warning notebook
* undo local notebook settings changes
* format notebook
* format workflow

---------

Co-authored-by: gagb <[email protected]>
1 parent 6fbb9e8 · commit 8cb5c14
Showing 7 changed files with 1,010 additions and 4 deletions.
17 changes: 13 additions & 4 deletions
autogen/agentchat/contrib/capabilities/context_handling.py
90 changes: 90 additions & 0 deletions
autogen/agentchat/contrib/capabilities/transform_messages.py
@@ -0,0 +1,90 @@
import copy
from typing import Dict, List

from termcolor import colored

from autogen import ConversableAgent

from .transforms import MessageTransform


class TransformMessages:
    """Agent capability for transforming messages before reply generation.

    This capability allows you to apply a series of message transformations to
    a ConversableAgent's incoming messages before they are processed for response
    generation. This is useful for tasks such as:

    - Limiting the number of messages considered for context.
    - Truncating messages to meet token limits.
    - Filtering sensitive information.
    - Customizing message formatting.

    To use `TransformMessages`:

    1. Create message transformations (e.g., `MessageHistoryLimiter`, `MessageTokenLimiter`).
    2. Instantiate `TransformMessages` with a list of these transformations.
    3. Add the `TransformMessages` instance to your `ConversableAgent` using `add_to_agent`.

    NOTE: Order of message transformations is important. You could get different results based on
        the order of transformations.

    Example:
        ```python
        from autogen import ConversableAgent
        from autogen.agentchat.contrib.capabilities.transform_messages import TransformMessages
        from autogen.agentchat.contrib.capabilities.transforms import MessageHistoryLimiter, MessageTokenLimiter

        max_messages = MessageHistoryLimiter(max_messages=2)
        truncate_messages = MessageTokenLimiter(max_tokens=500)
        transform_messages = TransformMessages(transforms=[max_messages, truncate_messages])

        agent = ConversableAgent(...)
        transform_messages.add_to_agent(agent)
        ```
""" | ||
|
||
def __init__(self, *, transforms: List[MessageTransform] = []): | ||
""" | ||
Args: | ||
transforms: A list of message transformations to apply. | ||
""" | ||
self._transforms = transforms | ||
|
||
def add_to_agent(self, agent: ConversableAgent): | ||
"""Adds the message transformations capability to the specified ConversableAgent. | ||
This function performs the following modifications to the agent: | ||
1. Registers a hook that automatically transforms all messages before they are processed for | ||
response generation. | ||
""" | ||
agent.register_hook(hookable_method="process_all_messages_before_reply", hook=self._transform_messages) | ||
|
||
def _transform_messages(self, messages: List[Dict]) -> List[Dict]: | ||
temp_messages = copy.deepcopy(messages) | ||
system_message = None | ||
|
||
if messages[0]["role"] == "system": | ||
system_message = copy.deepcopy(messages[0]) | ||
temp_messages.pop(0) | ||
|
||
for transform in self._transforms: | ||
temp_messages = transform.apply_transform(temp_messages) | ||
|
||
if system_message: | ||
temp_messages.insert(0, system_message) | ||
|
||
self._print_stats(messages, temp_messages) | ||
|
||
return temp_messages | ||
|
||
def _print_stats(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]): | ||
pre_transform_messages_len = len(pre_transform_messages) | ||
post_transform_messages_len = len(post_transform_messages) | ||
|
||
        if pre_transform_messages_len > post_transform_messages_len:
            print(
                colored(
                    f"Number of messages reduced from {pre_transform_messages_len} to {post_transform_messages_len}.",
                    "yellow",
                )
            )
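For orientation, here is a short usage sketch of the capability this file adds. It is not part of the diff: the import paths follow the file locations shown in this commit, and the agent name and `llm_config=False` are illustrative placeholders (in practice you would supply your own LLM configuration).

```python
from autogen import ConversableAgent
from autogen.agentchat.contrib.capabilities.transform_messages import TransformMessages
from autogen.agentchat.contrib.capabilities.transforms import MessageHistoryLimiter, MessageTokenLimiter

# Keep only the two most recent messages, then cap what remains at 500 tokens.
context_handling = TransformMessages(
    transforms=[
        MessageHistoryLimiter(max_messages=2),
        MessageTokenLimiter(max_tokens=500),
    ]
)

# llm_config=False disables LLM calls so the sketch stays self-contained.
assistant = ConversableAgent(name="assistant", llm_config=False)
context_handling.add_to_agent(assistant)
# Every reply the agent generates from here on sees only the transformed history.
```

Because the hook registers on `process_all_messages_before_reply` and `_transform_messages` works on a deep copy, the transformed view is what the agent replies to, while the messages passed in are left unmodified.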
210 changes: 210 additions & 0 deletions
autogen/agentchat/contrib/capabilities/transforms.py
@@ -0,0 +1,210 @@
import copy
import sys
from typing import Any, Dict, List, Optional, Protocol, Union

import tiktoken
from termcolor import colored

from autogen import token_count_utils


class MessageTransform(Protocol):
    """Defines a contract for message transformation.

    Classes implementing this protocol should provide an `apply_transform` method
    that takes a list of messages and returns the transformed list.
    """

    def apply_transform(self, messages: List[Dict]) -> List[Dict]:
        """Applies a transformation to a list of messages.

        Args:
            messages: A list of dictionaries representing messages.

        Returns:
            A new list of dictionaries containing the transformed messages.
        """
        ...


class MessageHistoryLimiter:
    """Limits the number of messages considered by an agent for response generation.

    This transform keeps only the most recent messages up to the specified maximum number of messages (max_messages).
    It trims the conversation history by removing older messages, retaining only the most recent messages.
    """

    def __init__(self, max_messages: Optional[int] = None):
        """
        Args:
            max_messages (None or int): Maximum number of messages to keep in the context.
                Must be greater than 0 if not None.
        """
        self._validate_max_messages(max_messages)
        self._max_messages = max_messages

    def apply_transform(self, messages: List[Dict]) -> List[Dict]:
        """Truncates the conversation history to the specified maximum number of messages.

        This method returns a new list containing the most recent messages up to the specified
        maximum number of messages (max_messages). If max_messages is None, it returns the
        original list of messages unmodified.

        Args:
            messages (List[Dict]): The list of messages representing the conversation history.

        Returns:
            List[Dict]: A new list containing the most recent messages up to the specified maximum.
        """
        if self._max_messages is None:
            return messages

        return messages[-self._max_messages :]

    def _validate_max_messages(self, max_messages: Optional[int]):
        if max_messages is not None and max_messages < 1:
            raise ValueError("max_messages must be None or an integer greater than 0")


class MessageTokenLimiter:
    """Truncates messages to meet token limits for efficient processing and response generation.

    This transformation applies two levels of truncation to the conversation history:

    1. Truncates each individual message to the maximum number of tokens specified by max_tokens_per_message.
    2. Truncates the overall conversation history to the maximum number of tokens specified by max_tokens.

    NOTE: Tokens are counted using the encoder for the specified model. Different models may yield different token
    counts for the same text.

    NOTE: For multimodal LLMs, the token count may be inaccurate as it does not account for the non-text input
    (e.g. images).

    The truncation process follows these steps in order:

    1. Messages are processed in reverse order (newest to oldest).
    2. Individual messages are truncated based on max_tokens_per_message. For multimodal messages containing both
        text and other types of content, only the text content is truncated.
    3. The overall conversation history is truncated based on the max_tokens limit. Once the accumulated token count
        exceeds this limit, the current message being processed as well as any remaining messages are discarded.
    4. The truncated conversation history is reconstructed by prepending the messages to a new list to preserve the
        original message order.
    """

    def __init__(
        self,
        max_tokens_per_message: Optional[int] = None,
        max_tokens: Optional[int] = None,
        model: str = "gpt-3.5-turbo-0613",
    ):
        """
        Args:
            max_tokens_per_message (None or int): Maximum number of tokens to keep in each message.
                Must be greater than or equal to 0 if not None.
            max_tokens (Optional[int]): Maximum number of tokens to keep in the chat history.
                Must be greater than or equal to 0 if not None.
            model (str): The target OpenAI model for tokenization alignment.
        """
        self._model = model
        self._max_tokens_per_message = self._validate_max_tokens(max_tokens_per_message)
        self._max_tokens = self._validate_max_tokens(max_tokens)

    def apply_transform(self, messages: List[Dict]) -> List[Dict]:
        """Applies token truncation to the conversation history.

        Args:
            messages (List[Dict]): The list of messages representing the conversation history.

        Returns:
            List[Dict]: A new list containing the truncated messages up to the specified token limits.
        """
        assert self._max_tokens_per_message is not None
        assert self._max_tokens is not None

        # Deep copy so that truncating message content does not mutate the caller's message dicts.
        temp_messages = copy.deepcopy(messages)
        processed_messages = []
        processed_messages_tokens = 0

        # calculate tokens for all messages
        total_tokens = sum(_count_tokens(msg["content"]) for msg in temp_messages)

        for msg in reversed(temp_messages):
            msg["content"] = self._truncate_str_to_tokens(msg["content"])
            msg_tokens = _count_tokens(msg["content"])

            # If adding this message would exceed the token limit, discard it and all remaining messages
            if processed_messages_tokens + msg_tokens > self._max_tokens:
                break

            # prepend the message to the list to preserve order
            processed_messages_tokens += msg_tokens
            processed_messages.insert(0, msg)

        if total_tokens > processed_messages_tokens:
            print(
                colored(
                    f"Truncated {total_tokens - processed_messages_tokens} tokens. Tokens reduced from {total_tokens} to {processed_messages_tokens}",
                    "yellow",
                )
            )

        return processed_messages

    def _truncate_str_to_tokens(self, contents: Union[str, List]) -> Union[str, List]:
        if isinstance(contents, str):
            return self._truncate_tokens(contents)
        elif isinstance(contents, list):
            return self._truncate_multimodal_text(contents)
        else:
            raise ValueError(f"Contents must be a string or a list of dictionaries. Received type: {type(contents)}")

    def _truncate_multimodal_text(self, contents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Truncates text content within a list of multimodal elements, preserving the overall structure."""
        tmp_contents = []
        for content in contents:
            if content["type"] == "text":
                truncated_text = self._truncate_tokens(content["text"])
                tmp_contents.append({"type": "text", "text": truncated_text})
            else:
                tmp_contents.append(content)
        return tmp_contents

    def _truncate_tokens(self, text: str) -> str:
        encoding = tiktoken.encoding_for_model(self._model)  # Get the appropriate tokenizer

        encoded_tokens = encoding.encode(text)
        truncated_tokens = encoded_tokens[: self._max_tokens_per_message]
        truncated_text = encoding.decode(truncated_tokens)  # Decode back to text

        return truncated_text

    def _validate_max_tokens(self, max_tokens: Optional[int] = None) -> Optional[int]:
        if max_tokens is not None and max_tokens < 0:
            raise ValueError("max_tokens and max_tokens_per_message must be None or greater than or equal to 0")

        try:
            allowed_tokens = token_count_utils.get_max_token_limit(self._model)
        except Exception:
            print(colored(f"Model {self._model} not found in token_count_utils.", "yellow"))
            allowed_tokens = None

        if max_tokens is not None and allowed_tokens is not None:
            if max_tokens > allowed_tokens:
                print(
                    colored(
                        f"Max token was set to {max_tokens}, but {self._model} can only accept {allowed_tokens} tokens. Capping it to {allowed_tokens}.",
                        "yellow",
                    )
                )
                return allowed_tokens

        return max_tokens if max_tokens is not None else sys.maxsize


def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
    token_count = 0
    if isinstance(content, str):
        token_count = token_count_utils.count_token(content)
    elif isinstance(content, list):
        for item in content:
            token_count += _count_tokens(item.get("text", ""))
    return token_count
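One note on the `MessageTransform` protocol above: it is structural (`typing.Protocol`), so any class exposing a compatible `apply_transform` method can be passed to `TransformMessages` without inheriting from anything. Below is a minimal illustrative sketch, not part of this commit; the `RedactKeyword` class and its parameters are made up to show the "filtering sensitive information" use case mentioned in the `TransformMessages` docstring.

```python
import copy
from typing import Dict, List


class RedactKeyword:
    """Illustrative transform: masks a keyword in string message content.

    Satisfies the MessageTransform protocol simply by providing apply_transform
    with the expected signature.
    """

    def __init__(self, keyword: str, mask: str = "[REDACTED]"):
        self._keyword = keyword
        self._mask = mask

    def apply_transform(self, messages: List[Dict]) -> List[Dict]:
        # Work on a copy so the caller's message dicts are left untouched.
        redacted = copy.deepcopy(messages)
        for msg in redacted:
            if isinstance(msg.get("content"), str):
                msg["content"] = msg["content"].replace(self._keyword, self._mask)
        return redacted
```

Such a transform could then be combined with the built-in ones, e.g. `TransformMessages(transforms=[RedactKeyword("secret-token"), MessageTokenLimiter(max_tokens=500)])`.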