Skip to content

Commit

Permalink
Initial commit with ability to add name into content with a transform
Browse files Browse the repository at this point in the history
  • Loading branch information
marklysze committed Aug 9, 2024
1 parent 972b4ed commit 7c98bc1
Show file tree
Hide file tree
Showing 2 changed files with 186 additions and 0 deletions.
92 changes: 92 additions & 0 deletions autogen/agentchat/contrib/capabilities/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,3 +445,95 @@ def _compress_text(self, text: str) -> Tuple[str, int]:
def _validate_min_tokens(self, min_tokens: Optional[int]):
if min_tokens is not None and min_tokens <= 0:
raise ValueError("min_tokens must be greater than 0 or None")


class TextMessageContentName:
"""A transform for including the agent's name in the content of a message."""

def __init__(
self,
position: str = "start",
format_string: str = "{name}:\n",
deduplicate: bool = True,
filter_dict: Optional[Dict] = None,
exclude_filter: bool = True,
):
"""
Args:
position (str): The position to add the name to the content. The possible options are 'start' or 'end'. Defaults to 'start'.
format_string (str): The f-string to format the message name with. Use '{name}' as a placeholder for the agent's name. Defaults to '{name}:\n' and must contain '{name}'.
deduplicate (bool): Whether to deduplicate the formatted string so it doesn't appear twice (sometimes the LLM will add it to new messages itself). Defaults to True.
filter_dict (None or dict): A dictionary to filter out messages that you want/don't want to compress.
If None, no filters will be applied.
exclude_filter (bool): If exclude filter is True (the default value), messages that match the filter will be
excluded from compression. If False, messages that match the filter will be compressed.
"""

assert isinstance(position, str) and position is not None
assert position in ["start", "end"]
assert isinstance(format_string, str) and format_string is not None
assert "{name}" in format_string
assert isinstance(deduplicate, bool) and deduplicate is not None

self._position = position
self._format_string = format_string
self._deduplicate = deduplicate
self._filter_dict = filter_dict
self._exclude_filter = exclude_filter

# Track the number of messages changed for logging
self._messages_changed = 0

def apply_transform(self, messages: List[Dict]) -> List[Dict]:
"""Applies the name change to the message based on the position and format string.
Args:
messages (List[Dict]): A list of message dictionaries.
Returns:
List[Dict]: A list of dictionaries with the message content updated with names.
"""
# Make sure there is at least one message
if not messages:
return messages

messages_changed = 0
processed_messages = copy.deepcopy(messages)
for message in processed_messages:
# Some messages may not have content.
if not transforms_util.is_content_right_type(
message.get("content")
) or not transforms_util.is_content_right_type(message.get("name")):
continue

if not transforms_util.should_transform_message(message, self._filter_dict, self._exclude_filter):
continue

if transforms_util.is_content_text_empty(message["content"]) or transforms_util.is_content_text_empty(
message["name"]
):
continue

# Get and format the name in the content
content = message["content"]
formatted_name = self._format_string.format(name=message["name"])

if self._position == "start":
if not self._deduplicate or not content.startswith(formatted_name):
message["content"] = f"{formatted_name}{content}"

messages_changed += 1
else:
if not self._deduplicate or not content.endswith(formatted_name):
message["content"] = f"{content}{formatted_name}"

messages_changed += 1

self._messages_changed = messages_changed
return processed_messages

def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
if self._messages_changed > 0:
return f"{self._messages_changed} message(s) changed to incorporate name.", True
else:
return "No messages changed to incorporate name.", False
94 changes: 94 additions & 0 deletions test/agentchat/contrib/capabilities/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
MessageHistoryLimiter,
MessageTokenLimiter,
TextMessageCompressor,
TextMessageContentName,
)
from autogen.agentchat.contrib.capabilities.transforms_util import count_text_tokens

Expand Down Expand Up @@ -60,6 +61,42 @@ def get_tool_messages_kept() -> List[Dict]:
]


def get_messages_with_names() -> List[Dict]:
return [
{"role": "system", "content": "I am the system."},
{"role": "user", "name": "charlie", "content": "I think the sky is blue."},
{"role": "user", "name": "mary", "content": "The sky is red."},
{"role": "user", "name": "bob", "content": "The sky is crimson."},
]


def get_messages_with_names_post_start() -> List[Dict]:
return [
{"role": "system", "content": "I am the system."},
{"role": "user", "name": "charlie", "content": "'charlie' said:\nI think the sky is blue."},
{"role": "user", "name": "mary", "content": "'mary' said:\nThe sky is red."},
{"role": "user", "name": "bob", "content": "'bob' said:\nThe sky is crimson."},
]


def get_messages_with_names_post_end() -> List[Dict]:
return [
{"role": "system", "content": "I am the system."},
{"role": "user", "name": "charlie", "content": "I think the sky is blue.\n(said 'charlie')"},
{"role": "user", "name": "mary", "content": "The sky is red.\n(said 'mary')"},
{"role": "user", "name": "bob", "content": "The sky is crimson.\n(said 'bob')"},
]


def get_messages_with_names_post_filtered() -> List[Dict]:
return [
{"role": "system", "content": "I am the system."},
{"role": "user", "name": "charlie", "content": "I think the sky is blue."},
{"role": "user", "name": "mary", "content": "'mary' said:\nThe sky is red."},
{"role": "user", "name": "bob", "content": "'bob' said:\nThe sky is crimson."},
]


def get_text_compressors() -> List[TextCompressor]:
compressors: List[TextCompressor] = [_MockTextCompressor()]
try:
Expand Down Expand Up @@ -300,6 +337,63 @@ def test_text_compression_with_filter(messages, text_compressor):
assert _filter_dict_test(post_transform, pre_transform, ["user"], exclude_filter=False)


@pytest.mark.parametrize("messages", [get_messages_with_names()])
def test_message_content_name(messages):
# Test including content name in messages

# Add name at the start with format: "'{name}' said:\n"
content_transform = TextMessageContentName(position="start", format_string="'{name}' said:\n")
transformed_messages = content_transform.apply_transform(messages=messages)

assert transformed_messages == get_messages_with_names_post_start()

# Add name at the end with format: "\n(said '{name}')"
content_transform = TextMessageContentName(position="end", format_string="\n(said '{name}')")
transformed_messages_end = content_transform.apply_transform(messages=messages)

assert transformed_messages_end == get_messages_with_names_post_end()

# Test filtering out exclusion
content_transform = TextMessageContentName(
position="start",
format_string="'{name}' said:\n",
filter_dict={"name": ["charlie"]},
exclude_filter=True, # Exclude
)

transformed_messages_end = content_transform.apply_transform(messages=messages)

assert transformed_messages_end == get_messages_with_names_post_filtered()

# Test filtering (inclusion)
content_transform = TextMessageContentName(
position="start",
format_string="'{name}' said:\n",
filter_dict={"name": ["mary", "bob"]},
exclude_filter=False, # Include
)

transformed_messages_end = content_transform.apply_transform(messages=messages)

assert transformed_messages_end == get_messages_with_names_post_filtered()

# Test instantiation
with pytest.raises(AssertionError):
TextMessageContentName(position=123) # Invalid type for position

with pytest.raises(AssertionError):
TextMessageContentName(position="middle") # Invalid value for position

with pytest.raises(AssertionError):
TextMessageContentName(format_string=123) # Invalid type for format_string

with pytest.raises(AssertionError):
TextMessageContentName(format_string="Agent:\n") # Missing '{name}' in format_string

with pytest.raises(AssertionError):
TextMessageContentName(deduplicate="yes") # Invalid type for deduplicate


if __name__ == "__main__":
long_messages = get_long_messages()
short_messages = get_short_messages()
Expand Down

0 comments on commit 7c98bc1

Please sign in to comment.