Skip to content

Commit

Permalink
Standardize printing of MessageTransforms (microsoft#2308)
Browse files Browse the repository at this point in the history
* Standardize printing of MessageTransforms

* Fix pre-commit black failure

* Add test for transform_messages printing

* Return str instead of printing

* Rename to_print_stats to verbose

* Cleanup

* t i# This is a combination of 3 commits.

Update requirements

* Remove lazy-fixture

* Avoid calling apply_transform in two code paths

* Format

* Replace stats with logs

* Handle no content messages in TokenLimiter get_logs()

* Move tests from test_transform_messages to test_transforms

---------

Co-authored-by: Wael Karkoub <[email protected]>
  • Loading branch information
giorgossideris and WaelKarkoub authored Apr 14, 2024
1 parent d473dee commit 9088390
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 141 deletions.
40 changes: 18 additions & 22 deletions autogen/agentchat/contrib/capabilities/transform_messages.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import copy
from typing import Dict, List

from termcolor import colored

from autogen import ConversableAgent

from ....formatting_utils import colored
from .transforms import MessageTransform


Expand Down Expand Up @@ -43,12 +42,14 @@ class TransformMessages:
```
"""

def __init__(self, *, transforms: List[MessageTransform] = []):
def __init__(self, *, transforms: List[MessageTransform] = [], verbose: bool = True):
"""
Args:
transforms: A list of message transformations to apply.
verbose: Whether to print logs of each transformation or not.
"""
self._transforms = transforms
self._verbose = verbose

def add_to_agent(self, agent: ConversableAgent):
"""Adds the message transformations capability to the specified ConversableAgent.
Expand All @@ -61,31 +62,26 @@ def add_to_agent(self, agent: ConversableAgent):
agent.register_hook(hookable_method="process_all_messages_before_reply", hook=self._transform_messages)

def _transform_messages(self, messages: List[Dict]) -> List[Dict]:
temp_messages = copy.deepcopy(messages)
post_transform_messages = copy.deepcopy(messages)
system_message = None

if messages[0]["role"] == "system":
system_message = copy.deepcopy(messages[0])
temp_messages.pop(0)
post_transform_messages.pop(0)

for transform in self._transforms:
temp_messages = transform.apply_transform(temp_messages)

if system_message:
temp_messages.insert(0, system_message)

self._print_stats(messages, temp_messages)
# deepcopy in case pre_transform_messages will later be used for logs printing
pre_transform_messages = (
copy.deepcopy(post_transform_messages) if self._verbose else post_transform_messages
)
post_transform_messages = transform.apply_transform(pre_transform_messages)

return temp_messages
if self._verbose:
logs_str, had_effect = transform.get_logs(pre_transform_messages, post_transform_messages)
if had_effect:
print(colored(logs_str, "yellow"))

def _print_stats(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]):
pre_transform_messages_len = len(pre_transform_messages)
post_transform_messages_len = len(post_transform_messages)
if system_message:
post_transform_messages.insert(0, system_message)

if pre_transform_messages_len < post_transform_messages_len:
print(
colored(
f"Number of messages reduced from {pre_transform_messages_len} to {post_transform_messages_len}.",
"yellow",
)
)
return post_transform_messages
60 changes: 45 additions & 15 deletions autogen/agentchat/contrib/capabilities/transforms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import sys
from typing import Any, Dict, List, Optional, Protocol, Union
from typing import Any, Dict, List, Optional, Protocol, Tuple, Union

import tiktoken
from termcolor import colored
Expand All @@ -25,6 +26,20 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
"""
...

def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
"""Creates the string including the logs of the transformation
Alongside the string, it returns a boolean indicating whether the transformation had an effect or not.
Args:
pre_transform_messages: A list of dictionaries representing messages before the transformation.
post_transform_messages: A list of dictionaries representig messages after the transformation.
Returns:
A tuple with a string with the logs and a flag indicating whether the transformation had an effect or not.
"""
...


class MessageHistoryLimiter:
"""Limits the number of messages considered by an agent for response generation.
Expand Down Expand Up @@ -60,6 +75,18 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:

return messages[-self._max_messages :]

def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
pre_transform_messages_len = len(pre_transform_messages)
post_transform_messages_len = len(post_transform_messages)

if post_transform_messages_len < pre_transform_messages_len:
logs_str = (
f"Removed {pre_transform_messages_len - post_transform_messages_len} messages. "
f"Number of messages reduced from {pre_transform_messages_len} to {post_transform_messages_len}."
)
return logs_str, True
return "No messages were removed.", False

def _validate_max_messages(self, max_messages: Optional[int]):
if max_messages is not None and max_messages < 1:
raise ValueError("max_messages must be None or greater than 1")
Expand Down Expand Up @@ -121,15 +148,10 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
assert self._max_tokens_per_message is not None
assert self._max_tokens is not None

temp_messages = messages.copy()
temp_messages = copy.deepcopy(messages)
processed_messages = []
processed_messages_tokens = 0

# calculate tokens for all messages
total_tokens = sum(
_count_tokens(msg["content"]) for msg in temp_messages if isinstance(msg.get("content"), (str, list))
)

for msg in reversed(temp_messages):
# Some messages may not have content.
if not isinstance(msg.get("content"), (str, list)):
Expand All @@ -154,16 +176,24 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
processed_messages_tokens += msg_tokens
processed_messages.insert(0, msg)

if total_tokens > processed_messages_tokens:
print(
colored(
f"Truncated {total_tokens - processed_messages_tokens} tokens. Tokens reduced from {total_tokens} to {processed_messages_tokens}",
"yellow",
)
)

return processed_messages

def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
pre_transform_messages_tokens = sum(
_count_tokens(msg["content"]) for msg in pre_transform_messages if "content" in msg
)
post_transform_messages_tokens = sum(
_count_tokens(msg["content"]) for msg in post_transform_messages if "content" in msg
)

if post_transform_messages_tokens < pre_transform_messages_tokens:
logs_str = (
f"Truncated {pre_transform_messages_tokens - post_transform_messages_tokens} tokens. "
f"Number of tokens reduced from {pre_transform_messages_tokens} to {post_transform_messages_tokens}"
)
return logs_str, True
return "No tokens were truncated.", False

def _truncate_str_to_tokens(self, contents: Union[str, List], n_tokens: int) -> Union[str, List]:
if isinstance(contents, str):
return self._truncate_tokens(contents, n_tokens)
Expand Down
104 changes: 0 additions & 104 deletions test/agentchat/contrib/capabilities/test_transform_messages.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import copy
import os
import sys
import tempfile
Expand All @@ -7,7 +6,6 @@
import pytest

import autogen
from autogen import token_count_utils
from autogen.agentchat.contrib.capabilities.transform_messages import TransformMessages
from autogen.agentchat.contrib.capabilities.transforms import MessageHistoryLimiter, MessageTokenLimiter

Expand All @@ -18,106 +16,6 @@
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST # noqa: E402


def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
token_count = 0
if isinstance(content, str):
token_count = token_count_utils.count_token(content)
elif isinstance(content, list):
for item in content:
token_count += _count_tokens(item.get("text", ""))
return token_count


def test_limit_token_transform():
"""
Test the TokenLimitTransform capability.
"""

messages = [
{"role": "user", "content": "short string"},
{
"role": "assistant",
"content": [{"type": "text", "text": "very very very very very very very very long string"}],
},
]

# check if token limit per message is not exceeded.
max_tokens_per_message = 5
token_limit_transform = MessageTokenLimiter(max_tokens_per_message=max_tokens_per_message)
transformed_messages = token_limit_transform.apply_transform(copy.deepcopy(messages))

for message in transformed_messages:
assert _count_tokens(message["content"]) <= max_tokens_per_message

# check if total token limit is not exceeded.
max_tokens = 10
token_limit_transform = MessageTokenLimiter(max_tokens=max_tokens)
transformed_messages = token_limit_transform.apply_transform(copy.deepcopy(messages))

token_count = 0
for message in transformed_messages:
token_count += _count_tokens(message["content"])

assert token_count <= max_tokens
assert len(transformed_messages) <= len(messages)

# check if token limit per message works nicely with total token limit.
token_limit_transform = MessageTokenLimiter(max_tokens=max_tokens, max_tokens_per_message=max_tokens_per_message)

transformed_messages = token_limit_transform.apply_transform(copy.deepcopy(messages))

token_count = 0
for message in transformed_messages:
token_count_local = _count_tokens(message["content"])
token_count += token_count_local
assert token_count_local <= max_tokens_per_message

assert token_count <= max_tokens
assert len(transformed_messages) <= len(messages)


def test_limit_token_transform_without_content():
"""Test the TokenLimitTransform with messages that don't have content."""

messages = [{"role": "user", "function_call": "example"}, {"role": "assistant", "content": None}]

# check if token limit per message works nicely with total token limit.
token_limit_transform = MessageTokenLimiter(max_tokens=10, max_tokens_per_message=5)

transformed_messages = token_limit_transform.apply_transform(copy.deepcopy(messages))

assert len(transformed_messages) == len(messages)


def test_limit_token_transform_total_token_count():
"""Tests if the TokenLimitTransform truncates without dropping messages."""
messages = [{"role": "very very very very very"}]

token_limit_transform = MessageTokenLimiter(max_tokens=1)
transformed_messages = token_limit_transform.apply_transform(copy.deepcopy(messages))

assert len(transformed_messages) == 1


def test_max_message_history_length_transform():
"""
Test the MessageHistoryLimiter capability to limit the number of messages.
"""
messages = [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": [{"type": "text", "text": "there"}]},
{"role": "user", "content": "how"},
{"role": "assistant", "content": [{"type": "text", "text": "are you doing?"}]},
]

max_messages = 2
messages_limiter = MessageHistoryLimiter(max_messages=max_messages)
transformed_messages = messages_limiter.apply_transform(copy.deepcopy(messages))

assert len(transformed_messages) == max_messages
assert transformed_messages == messages[max_messages:]


@pytest.mark.skipif(skip_openai, reason="Requested to skip openai test.")
def test_transform_messages_capability():
"""Test the TransformMessages capability to handle long contexts.
Expand Down Expand Up @@ -172,6 +70,4 @@ def test_transform_messages_capability():


if __name__ == "__main__":
test_limit_token_transform()
test_max_message_history_length_transform()
test_transform_messages_capability()
Loading

0 comments on commit 9088390

Please sign in to comment.