Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MCTS Reasoning Agent #175

Merged
merged 27 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions autogen/agentchat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
from .agent import Agent
from .assistant_agent import AssistantAgent
from .chat import ChatResult, initiate_chats
from .contrib.reasoning_agent import (
ReasoningAgent,
ThinkNode,
visualize_tree,
)

# Imported last to avoid circular imports
from .contrib.swarm_agent import (
Expand Down Expand Up @@ -41,4 +46,7 @@
"AFTER_WORK",
"AfterWorkOption",
"UPDATE_SYSTEM_MESSAGE",
"ReasoningAgent",
"visualize_tree",
"ThinkNode",
]
524 changes: 451 additions & 73 deletions autogen/agentchat/contrib/reasoning_agent.py

Large diffs are not rendered by default.

7,525 changes: 5,791 additions & 1,734 deletions notebook/agentchat_reasoning_agent.ipynb

Large diffs are not rendered by default.

1,855 changes: 928 additions & 927 deletions notebook/autobuild_agent_library.ipynb

Large diffs are not rendered by default.

51 changes: 21 additions & 30 deletions test/agentchat/contrib/test_reasoning_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@

# Test data
TEST_QUESTION = "What is the capital of France?"
TEST_TRAJECTORY = """# Question: What is the capital of France?
TEST_TRAJECTORY = """# Question:
What is the capital of France?
---

Step 1: Let me think about this systematically
Step 2: France is a country in Europe
Step 3: Paris is the capital city of France"""
Expand All @@ -51,7 +54,7 @@ def reasoning_agent():
def test_think_node_init(think_node):
"""Test ThinkNode initialization"""
assert think_node.content == TEST_CONTENT
assert think_node.value is None
assert think_node.value == 0
assert think_node.parent is None
assert think_node.depth == 0
assert think_node.children == []
Expand All @@ -60,13 +63,14 @@ def test_think_node_init(think_node):

def test_think_node_trajectory(think_node):
"""Test ThinkNode trajectory property"""
assert think_node._trajectory_arr == ["# Question: " + TEST_CONTENT]
assert "# Question: " + TEST_CONTENT in think_node.trajectory
first_line = "# Question:\n" + TEST_CONTENT + "\n---\n"
assert think_node._trajectory_arr == [first_line]
assert first_line in think_node.trajectory


def test_think_node_str_repr(think_node):
"""Test ThinkNode string representation"""
expected = f"{TEST_CONTENT} -> Depth: 0 Value: None Visits: 0"
expected = f"{TEST_CONTENT} -> Depth: 0 Value: 0 Visits: 0"
assert str(think_node) == expected
assert repr(think_node) == expected

Expand All @@ -75,7 +79,7 @@ def test_think_node_to_dict(think_node):
"""Test ThinkNode to_dict method"""
node_dict = think_node.to_dict()
assert node_dict["content"] == TEST_CONTENT
assert node_dict["value"] is None
assert node_dict["value"] == 0
assert node_dict["depth"] == 0
assert node_dict["visits"] == 0
assert node_dict["children"] == []
Expand All @@ -96,21 +100,12 @@ def test_think_node_from_dict():
def test_reasoning_agent_init(reasoning_agent):
"""Test ReasoningAgent initialization"""
assert reasoning_agent.name == "reasoning_agent"
assert reasoning_agent.max_depth == 4
assert reasoning_agent.beam_size == 3
assert reasoning_agent.answer_approach == "pool"
assert reasoning_agent._max_depth == 4
assert reasoning_agent._beam_size == 3
assert reasoning_agent._answer_approach == "pool"
assert reasoning_agent._root is None


def test_reasoning_agent_invalid_approach():
"""Test ReasoningAgent with invalid answer approach"""
config_list = [{"model": "gpt-4o-mini", "api_key": "fake_key"}]
llm_config = {"config_list": config_list}

with pytest.raises(AssertionError):
ReasoningAgent("reasoning_agent", llm_config=llm_config, answer_approach="invalid")


def test_think_node_with_parent():
"""Test ThinkNode parent-child relationship"""
parent = ThinkNode(content="Parent node")
Expand Down Expand Up @@ -172,9 +167,7 @@ def helper_test_reasoning_agent_answer(max_depth, beam_size, answer_approach):
agent = ReasoningAgent(
"test_agent",
llm_config=mock_config,
max_depth=max_depth,
beam_size=beam_size,
answer_approach=answer_approach,
reason_config={"beam_size": beam_size, "answer_approach": answer_approach, "max_depth": max_depth},
)

def mock_response(*args, **kwargs):
Expand All @@ -199,14 +192,12 @@ def mock_response(*args, **kwargs):

mock_oai_reply.side_effect = mock_response

print("OAI REPLY:", agent.thinker.generate_oai_reply)
print("OAI REPLY:", agent._thinker.generate_oai_reply)

success, response = agent.generate_response(
messages=[{"role": "user", "content": "Test question"}], sender=None
)
response = agent._beam_reply("Test question")
assert len(response)

assert success is True
assert "TERMINATE" in agent.thinker.last_message()["content"]
assert "TERMINATE" in agent._thinker.last_message()["content"]

# Verify we didn't exceed max_depth
current_node = agent._root
Expand All @@ -218,7 +209,7 @@ def mock_response(*args, **kwargs):
max_depth_found = max(max_depth_found, node.depth)
nodes_to_check.extend(node.children)

assert max_depth_found <= agent.max_depth
assert max_depth_found <= agent._max_depth


@patch("graphviz.Digraph")
Expand Down Expand Up @@ -252,8 +243,8 @@ def test_visualize_tree_successful_case(mock_digraph):
expected_calls = [
call("0", "Root\n visits: 1\n value: 0.5"),
call("0_0", "Child 1\n visits: 2\n value: 0.7"),
call("0_1", "Child 2\n visits: 0\n value: None"),
call("0_0_0", "Grandchild with very long content that should be t...\n visits: 0\n value: None"),
call("0_1", "Child 2\n visits: 0\n value: 0"),
call("0_0_0", "Grandchild with very long content that should be t...\n visits: 0\n value: 0"),
]
mock_graph.node.assert_has_calls(expected_calls, any_order=True)

Expand Down
116 changes: 24 additions & 92 deletions website/blog/2024-12-02-ReasoningAgent2/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
title: ReasoningAgent - Tree of Thoughts with Beam Search in AG2
authors:
- Hk669
- BabyCNM
- skzhang1
- sonichi
- BabyCNM
- qingyunwu
tags: [LLM, GPT, research]
---
Expand Down Expand Up @@ -60,8 +60,10 @@ reason_agent = ReasoningAgent(
name="reason_agent",
llm_config={"config_list": config_list},
verbose=False,
beam_size=1, # Using beam size 1 for O1-style reasoning
max_depth=3,
reason_config={
"beam_size": 1, # Using beam size 1 for O1-style reasoning
"max_depth": 3
}
)
```

Expand All @@ -74,8 +76,14 @@ Here's a simple example of using ReasoningAgent:

```python
import os
from autogen import AssistantAgent, UserProxyAgent
from autogen.agentchat.contrib.reasoning_agent import ReasoningAgent, visualize_tree
from autogen import (
AssistantAgent,
UserProxyAgent,
ReasoningAgent,
ThinkNode,
visualize_tree
)


# Configure the model
config_list = [{"model": "gpt-4", "api_key": os.environ.get("OPENAI_API_KEY")}]
Expand All @@ -85,8 +93,10 @@ reasoning_agent = ReasoningAgent(
name="reason_agent",
llm_config={"config_list": config_list},
verbose=False,
beam_size=1, # Using beam size 1 for O1-style reasoning
max_depth=3,
reason_config={
"beam_size": 1, # Using beam size 1 for O1-style reasoning
"max_depth": 3
}
)

# Create a user proxy agent
Expand Down Expand Up @@ -140,8 +150,10 @@ reason_agent = ReasoningAgent(
name="reason_agent",
llm_config={"config_list": config_list},
verbose=False,
beam_size=3, # Explore 3 paths in parallel
max_depth=3,
reason_config={
"beam_size": 3,
"max_depth": 3
}
)

# Example complex problem
Expand Down Expand Up @@ -180,6 +192,7 @@ After asking a question to the `ReasoningAgent`, you only need to simply call th

```python
import json

data = reasoning_agent._root.to_dict()
with open("reasoning_tree.json", "w") as f:
json.dump(data, f)
Expand All @@ -202,43 +215,7 @@ new_node = pickle.load(open("reasoning_tree.pkl", "rb"))
This step finds the best trajectory in the thought tree and converts it to an SFT dataset as a sequence of strings. The best trajectory is determined by following the highest-scoring path from root to leaf.

```python
def extract_sft_dataset(root):
"""
Extract the best trajectory or multiple equally good trajectories
for SFT training.

Args:
root: The root node of the tree.

Returns:
List of best trajectories, where each trajectory is a pair of instruction and response.
"""
instruction = root.content
idx = len("# Question: ") + len(root.content) + 1

def find_leaf_nodes(node):
"""Recursively find all leaf nodes."""
if not node.children:
return [node]
leafs = []
for child in node.children:
leafs.extend(find_leaf_nodes(child))
return leafs

# Step 1: Find all leaf nodes
leaf_nodes = find_leaf_nodes(root)

# Step 2: Determine the highest score among leaf nodes
max_value = max(leaf_nodes, key=lambda x: x.value).value

# Step 3: Collect all leaf nodes with the highest score
best_leafs = [leaf for leaf in leaf_nodes if leaf.value == max_value]

# Step 4: Collect trajectories for all the best leaf nodes
best_trajectories = [{"instruction": instruction, "response": leaf.trajectory[idx:]} for leaf in best_leafs]

return best_trajectories

from autogen.agentchat.contrib.reasoning_agent import extract_sft_dataset

# Example usage
sft_data = extract_sft_dataset(reason_agent._root)
Expand All @@ -249,52 +226,7 @@ json.dump(sft_data, open("sft_data.json", "w"), indent=2)
This step generates preference pairs by comparing sibling nodes in the tree. For each parent node with multiple children, we create training pairs where the higher-scored response is marked as preferred over the lower-scored one.

```python
def extract_rlhf_preference_dataset(root, contrastive_threshold=0.2):
"""
Extract and generate preference pairs for RLHF training by comparing sibling nodes.

Args:
root: The root node of the tree.
        contrastive_threshold (float): between (0, 1), a distance measure at which we are confident
            to call one response positive and the other negative.

Returns:
A list of preference pairs, where each pair contains two responses and
indicates which one is preferred.
"""
preference_pairs = []

assert contrastive_threshold > 0
assert contrastive_threshold < 1

def traverse_tree(node):
"""Traverse the tree to compare sibling nodes and collect preferences."""
if not node.children:
return # Leaf node, no comparisons needed

# Step 1: Compare all sibling nodes
for i in range(len(node.children)):
for j in range(len(node.children)):
if i == j:
continue
child_a, child_b = node.children[i], node.children[j]
if child_a.value - child_b.value > contrastive_threshold:
preference_pairs.append({
"instruction": node.trajectory,
"preferred_response": f"Step {child_a.depth}: {child_a.content}",
"dispreferred_response": f"Step {child_b.depth}: {child_b.content}",
})


# Step 2: Recurse into child nodes
for child in node.children:
traverse_tree(child)

# Start traversal from the root
traverse_tree(root)

return preference_pairs

from autogen.agentchat.contrib.reasoning_agent import extract_rlhf_preference_dataset

# Example usage
rlhf_data = extract_rlhf_preference_dataset(reason_agent._root)
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Loading