From 972b4ed024fbf5657467196832017ac9380da880 Mon Sep 17 00:00:00 2001
From: Gaoxiang Luo
Date: Thu, 8 Aug 2024 20:14:34 -0700
Subject: [PATCH] Fix message history limiter for tool call (#3178)

* fix: message history limiter to support tool calls

* add: pytest and docs for message history limiter for tool calls

* Added keep_first_message for HistoryLimiter transform

* Update to inbetween to between

* Updated keep_first_message to non-optional, logic for history limiter

* Update transforms.py

* Update test_transforms to match utils introduction, add keep_first_message testing

* Update test_transforms.py for pre-commit checks

---------

Co-authored-by: Mark Sze <66362098+marklysze@users.noreply.github.com>
Co-authored-by: Chi Wang
---
 .../contrib/capabilities/transforms.py        |  30 ++++-
 .../contrib/capabilities/test_transforms.py   | 115 ++++++++++++------
 .../intro_to_transform_messages.md            |  23 +++-
 3 files changed, 128 insertions(+), 40 deletions(-)

diff --git a/autogen/agentchat/contrib/capabilities/transforms.py b/autogen/agentchat/contrib/capabilities/transforms.py
index dad3fc335edf..7cd7fdb92a35 100644
--- a/autogen/agentchat/contrib/capabilities/transforms.py
+++ b/autogen/agentchat/contrib/capabilities/transforms.py
@@ -53,13 +53,16 @@ class MessageHistoryLimiter:
     It trims the conversation history by removing older messages, retaining only the most recent messages.
     """
 
-    def __init__(self, max_messages: Optional[int] = None):
+    def __init__(self, max_messages: Optional[int] = None, keep_first_message: bool = False):
         """
         Args:
             max_messages Optional[int]: Maximum number of messages to keep in the context. Must be greater than 0 if not None.
+            keep_first_message bool: Whether to keep the original first message in the conversation history.
+                Defaults to False.
         """
         self._validate_max_messages(max_messages)
         self._max_messages = max_messages
+        self._keep_first_message = keep_first_message
 
     def apply_transform(self, messages: List[Dict]) -> List[Dict]:
         """Truncates the conversation history to the specified maximum number of messages.
@@ -75,10 +78,31 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
             List[Dict]: A new list containing the most recent messages up to the specified maximum.
         """
 
-        if self._max_messages is None:
+        if self._max_messages is None or len(messages) <= self._max_messages:
             return messages
 
-        return messages[-self._max_messages :]
+        truncated_messages = []
+        remaining_count = self._max_messages
+
+        # Reserve one slot for the original first message when it must be kept
+        if self._keep_first_message:
+            truncated_messages = [messages[0]]
+            remaining_count -= 1
+
+        # Walk from newest to oldest (index 0 excluded), inserting at the front so chronological order is preserved
+        for i in range(len(messages) - 1, 0, -1):
+            if remaining_count > 1:
+                truncated_messages.insert(1 if self._keep_first_message else 0, messages[i])
+            if remaining_count == 1:
+                # A 'tool' response in the last slot would be orphaned from its 'tool_calls' message, so skip it.
+                if messages[i].get("role") != "tool":
+                    truncated_messages.insert(1 if self._keep_first_message else 0, messages[i])
+
+            remaining_count -= 1
+            if remaining_count == 0:
+                break
+
+        return truncated_messages
 
     def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
         pre_transform_messages_len = len(pre_transform_messages)
diff --git a/test/agentchat/contrib/capabilities/test_transforms.py b/test/agentchat/contrib/capabilities/test_transforms.py
index 46c61d9adc6f..34094a0008b7 100644
--- a/test/agentchat/contrib/capabilities/test_transforms.py
+++ b/test/agentchat/contrib/capabilities/test_transforms.py
@@ -9,8 +9,8 @@
     MessageHistoryLimiter,
     MessageTokenLimiter,
     TextMessageCompressor,
-    _count_tokens,
 )
+from autogen.agentchat.contrib.capabilities.transforms_util import count_text_tokens
 
 
 class _MockTextCompressor:
@@ -40,6 +40,26 @@ def get_no_content_messages() -> List[Dict]:
     return [{"role": "user", "function_call": "example"}, {"role": "assistant", "content": None}]
 
 
+def get_tool_messages() -> List[Dict]:
+    return [
+        {"role": "user", "content": "hello"},
+        {"role": "tool_calls", "content": "calling_tool"},
+        {"role": "tool", "content": "tool_response"},
+        {"role": "user", "content": "how are you"},
+        {"role": "assistant", "content": [{"type": "text", "text": "are you doing?"}]},
+    ]
+
+
+def get_tool_messages_kept() -> List[Dict]:
+    return [
+        {"role": "user", "content": "hello"},
+        {"role": "tool_calls", "content": "calling_tool"},
+        {"role": "tool", "content": "tool_response"},
+        {"role": "tool_calls", "content": "calling_tool"},
+        {"role": "tool", "content": "tool_response"},
+    ]
+
+
 def get_text_compressors() -> List[TextCompressor]:
     compressors: List[TextCompressor] = [_MockTextCompressor()]
     try:
@@ -57,6 +77,11 @@ def message_history_limiter() -> MessageHistoryLimiter:
     return MessageHistoryLimiter(max_messages=3)
 
 
+@pytest.fixture
+def message_history_limiter_keep_first() -> MessageHistoryLimiter:
+    return MessageHistoryLimiter(max_messages=3, keep_first_message=True)
+
+
 @pytest.fixture
 def message_token_limiter() -> MessageTokenLimiter:
     return MessageTokenLimiter(max_tokens_per_message=3)
@@ -96,12 +121,43 @@ def _filter_dict_test(
 
 @pytest.mark.parametrize(
     "messages, expected_messages_len",
-    [(get_long_messages(), 3), (get_short_messages(), 3), (get_no_content_messages(), 2)],
+    [
+        (get_long_messages(), 3),
+        (get_short_messages(), 3),
+        (get_no_content_messages(), 2),
+        (get_tool_messages(), 2),
+        (get_tool_messages_kept(), 2),
+    ],
 )
 def test_message_history_limiter_apply_transform(message_history_limiter, messages, expected_messages_len):
     transformed_messages = message_history_limiter.apply_transform(messages)
     assert len(transformed_messages) == expected_messages_len
 
+    if messages == get_tool_messages_kept():
+        assert transformed_messages[0]["role"] == "tool_calls"
+        assert transformed_messages[1]["role"] == "tool"
+
+
+@pytest.mark.parametrize(
+    "messages, expected_messages_len",
+    [
+        (get_long_messages(), 3),
+        (get_short_messages(), 3),
+        (get_no_content_messages(), 2),
+        (get_tool_messages(), 3),
+        (get_tool_messages_kept(), 3),
+    ],
+)
+def test_message_history_limiter_apply_transform_keep_first(
+    message_history_limiter_keep_first, messages, expected_messages_len
+):
+    transformed_messages = message_history_limiter_keep_first.apply_transform(messages)
+    assert len(transformed_messages) == expected_messages_len
+
+    if messages == get_tool_messages_kept():
+        assert transformed_messages[1]["role"] == "tool_calls"
+        assert transformed_messages[2]["role"] == "tool"
+
 
 @pytest.mark.parametrize(
     "messages, expected_logs, expected_effect",
@@ -109,6 +165,8 @@ def test_message_history_limiter_apply_transform(message_history_limiter, messag
         (get_long_messages(), "Removed 2 messages. Number of messages reduced from 5 to 3.", True),
         (get_short_messages(), "No messages were removed.", False),
         (get_no_content_messages(), "No messages were removed.", False),
+        (get_tool_messages(), "Removed 3 messages. Number of messages reduced from 5 to 2.", True),
+        (get_tool_messages_kept(), "Removed 3 messages. Number of messages reduced from 5 to 2.", True),
     ],
 )
 def test_message_history_limiter_get_logs(message_history_limiter, messages, expected_logs, expected_effect):
@@ -131,7 +189,8 @@ def test_message_token_limiter_apply_transform(
 ):
     transformed_messages = message_token_limiter.apply_transform(copy.deepcopy(messages))
     assert (
-        sum(_count_tokens(msg["content"]) for msg in transformed_messages if "content" in msg) == expected_token_count
+        sum(count_text_tokens(msg["content"]) for msg in transformed_messages if "content" in msg)
+        == expected_token_count
     )
     assert len(transformed_messages) == expected_messages_len
 
@@ -167,7 +226,8 @@ def test_message_token_limiter_with_threshold_apply_transform(
 ):
     transformed_messages = message_token_limiter_with_threshold.apply_transform(messages)
     assert (
-        sum(_count_tokens(msg["content"]) for msg in transformed_messages if "content" in msg) == expected_token_count
+        sum(count_text_tokens(msg["content"]) for msg in transformed_messages if "content" in msg)
+        == expected_token_count
     )
     assert len(transformed_messages) == expected_messages_len
 
@@ -240,56 +300,31 @@ def test_text_compression_with_filter(messages, text_compressor):
     assert _filter_dict_test(post_transform, pre_transform, ["user"], exclude_filter=False)
 
 
-@pytest.mark.parametrize("text_compressor", get_text_compressors())
-def test_text_compression_cache(text_compressor):
-    messages = get_long_messages()
-    mock_compressed_content = (1, {"content": "mock"})
-
-    with patch(
-        "autogen.agentchat.contrib.capabilities.transforms.TextMessageCompressor._cache_get",
-        MagicMock(return_value=(1, {"content": "mock"})),
-    ) as mocked_get, patch(
-        "autogen.agentchat.contrib.capabilities.transforms.TextMessageCompressor._cache_set", MagicMock()
-    ) as mocked_set:
-        compressor = TextMessageCompressor(text_compressor=text_compressor)
-
-        compressor.apply_transform(messages)
-        compressor.apply_transform(messages)
-
-        assert mocked_get.call_count == len(messages)
-        assert mocked_set.call_count == len(messages)
-
-    # We already populated the cache with the mock content
-    # We need to test if we retrieve the correct content
-    compressor = TextMessageCompressor(text_compressor=text_compressor)
-    compressed_messages = compressor.apply_transform(messages)
-
-    for message in compressed_messages:
-        assert message["content"] == mock_compressed_content[1]
-
-
 if __name__ == "__main__":
     long_messages = get_long_messages()
     short_messages = get_short_messages()
     no_content_messages = get_no_content_messages()
+    tool_messages = get_tool_messages()
     msg_history_limiter = MessageHistoryLimiter(max_messages=3)
+    msg_history_limiter_keep_first = MessageHistoryLimiter(max_messages=3, keep_first_message=True)
     msg_token_limiter = MessageTokenLimiter(max_tokens_per_message=3)
     msg_token_limiter_with_threshold = MessageTokenLimiter(max_tokens_per_message=1, min_tokens=10)
 
     # Test Parameters
     message_history_limiter_apply_transform_parameters = {
-        "messages": [long_messages, short_messages, no_content_messages],
-        "expected_messages_len": [3, 3, 2],
+        "messages": [long_messages, short_messages, no_content_messages, tool_messages],
+        "expected_messages_len": [3, 3, 2, 2],
     }
 
     message_history_limiter_get_logs_parameters = {
-        "messages": [long_messages, short_messages, no_content_messages],
+        "messages": [long_messages, short_messages, no_content_messages, tool_messages],
         "expected_logs": [
             "Removed 2 messages. Number of messages reduced from 5 to 3.",
             "No messages were removed.",
             "No messages were removed.",
+            "Removed 3 messages. Number of messages reduced from 5 to 2.",
         ],
-        "expected_effect": [True, False, False],
+        "expected_effect": [True, False, False, True],
     }
 
     message_token_limiter_apply_transform_parameters = {
@@ -322,6 +357,14 @@ def test_text_compression_cache(text_compressor):
     ):
         test_message_history_limiter_apply_transform(msg_history_limiter, messages, expected_messages_len)
 
+    for messages, expected_messages_len in zip(
+        message_history_limiter_apply_transform_parameters["messages"],
+        [3, 3, 2, 3],  # keep_first_message=True always retains the first message, so the tool case keeps 3
+    ):
+        test_message_history_limiter_apply_transform_keep_first(
+            msg_history_limiter_keep_first, messages, expected_messages_len
+        )
+
     for messages, expected_logs, expected_effect in zip(
         message_history_limiter_get_logs_parameters["messages"],
         message_history_limiter_get_logs_parameters["expected_logs"],
diff --git a/website/docs/topics/handling_long_contexts/intro_to_transform_messages.md b/website/docs/topics/handling_long_contexts/intro_to_transform_messages.md
index d0a53702c48b..52fea15d01e5 100644
--- a/website/docs/topics/handling_long_contexts/intro_to_transform_messages.md
+++ b/website/docs/topics/handling_long_contexts/intro_to_transform_messages.md
@@ -59,7 +59,28 @@ pprint.pprint(processed_messages)
 {'content': 'very very very very very very long string', 'role': 'user'}]
 ```
 
-By applying the `MessageHistoryLimiter`, we can see that we were able to limit the context history to the 3 most recent messages.
+By applying the `MessageHistoryLimiter`, we can see that we were able to limit the context history to the 3 most recent messages. However, if the splitting point is between a "tool_calls" and "tool" pair, the orphaned "tool" message is dropped as well, so that the result obeys the OpenAI API call constraints.
+
+```python
+max_msg_transfrom = transforms.MessageHistoryLimiter(max_messages=3)
+
+messages = [
+    {"role": "user", "content": "hello"},
+    {"role": "tool_calls", "content": "calling_tool"},
+    {"role": "tool", "content": "tool_response"},
+    {"role": "user", "content": "how are you"},
+    {"role": "assistant", "content": [{"type": "text", "text": "are you doing?"}]},
+]
+
+processed_messages = max_msg_transfrom.apply_transform(copy.deepcopy(messages))
+pprint.pprint(processed_messages)
+```
+```console
+[{'content': 'how are you', 'role': 'user'},
+{'content': [{'text': 'are you doing?', 'type': 'text'}], 'role': 'assistant'}]
+```
+
+Only two messages are returned here: the third slot would have gone to the "tool" response, whose "tool_calls" message falls outside the window, so it is dropped.
 
 #### Example 2: Limiting the Number of Tokens
 