Skip to content

Commit

Permalink
mcts init
Browse files Browse the repository at this point in the history
  • Loading branch information
BabyCNM committed Dec 10, 2024
1 parent 6b52889 commit dfda100
Show file tree
Hide file tree
Showing 4 changed files with 426 additions and 112 deletions.
257 changes: 231 additions & 26 deletions autogen/agentchat/contrib/reasoning_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

from ..agent import Agent
from ..assistant_agent import AssistantAgent
from ..assistant_agent import AssistantAgent
import random
import math

# Tiny offset added to visit counts in the MCTS selection formula so that
# unvisited nodes do not trigger a division by zero or a log(0).
EPSILON = 1e-6


TreeofThought_message = """
Role: Expert Planning AI Assistant
Expand Down Expand Up @@ -63,11 +68,11 @@ def __init__(self, content: str, parent: Optional["ThinkNode"] = None) -> None:
- Providing trajectory utilities to get the full path from root to this node
"""
self.content = content
self.value = None
self.value = 0
self.parent = parent
self.depth = self.parent.depth + 1 if parent else 0
self.children = []
self.visits = 0 # TODO: remove this line if not used.
self.visits = 0
if self.parent:
self.parent.children.append(self)

Expand Down Expand Up @@ -175,18 +180,110 @@ def add_nodes(node: ThinkNode, node_id: str = "0"):
print("Make sure graphviz is installed on your system: https://graphviz.org/download/")



def extract_sft_dataset(root):
    """
    Build SFT training examples from the top-scoring leaves of a reasoning tree.

    Every leaf whose ``value`` equals the maximum leaf value contributes one
    example, so ties produce several examples.

    Args:
        root: The root node of the tree. Its ``content`` is used as the instruction.

    Returns:
        List of dicts with keys ``"instruction"`` and ``"response"``; the response
        is the leaf's trajectory with the leading question header removed.
    """
    question = root.content
    # Number of leading characters to drop from each trajectory: the
    # "# Question: " prefix, the question text itself, and one separator
    # character (presumably a newline -- depends on the trajectory format).
    header_len = len("# Question: ") + len(question) + 1

    # Depth-first, left-to-right walk using an explicit stack; children are
    # pushed in reverse so they are popped in their original order.
    leaves = []
    pending = [root]
    while pending:
        current = pending.pop()
        if current.children:
            pending.extend(reversed(current.children))
        else:
            leaves.append(current)

    top_score = max(leaf.value for leaf in leaves)

    examples = []
    for leaf in leaves:
        if leaf.value == top_score:
            examples.append({"instruction": question, "response": leaf.trajectory[header_len:]})
    return examples


def extract_rlhf_preference_dataset(root, contrastive_threshold=0.2):
    """
    Generate RLHF preference pairs by comparing every ordered pair of siblings.

    Args:
        root: The root node of the tree.
        contrastive_threshold (float): strictly between 0 and 1; the minimum score
            margin at which we are confident to call one sibling positive and the
            other negative.

    Returns:
        A list of dicts with keys ``"instruction"`` (the parent's trajectory),
        ``"preferred_response"`` and ``"dispreferred_response"``.
    """
    assert 0 < contrastive_threshold < 1

    pairs = []
    # Pre-order walk with an explicit stack; children are pushed reversed so
    # that parents are handled before children and siblings left to right.
    pending = [root]
    while pending:
        parent = pending.pop()
        siblings = parent.children
        if not siblings:
            continue  # Leaf node, nothing to compare

        for a_pos, winner in enumerate(siblings):
            for b_pos, loser in enumerate(siblings):
                if a_pos == b_pos:
                    continue

                if winner.visits > 0 and loser.visits > 0:
                    # MCTS: compare average value per visit
                    margin = winner.value / winner.visits - loser.value / loser.visits
                else:
                    # Beam search: compare raw values
                    margin = winner.value - loser.value

                if margin > contrastive_threshold:
                    pairs.append(
                        {
                            "instruction": parent.trajectory,
                            "preferred_response": f"Step {winner.depth}: {winner.content}",
                            "dispreferred_response": f"Step {loser.depth}: {loser.content}",
                        }
                    )

        pending.extend(reversed(siblings))

    return pairs

class ReasoningAgent(AssistantAgent):
def __init__(
self, name, llm_config, max_depth=4, beam_size=3, answer_approach="pool", verbose=True, **kwargs
self, name, llm_config, max_depth=4, beam_size=3, answer_approach="pool", verbose=True, reason_config: dict=None, **kwargs
) -> None:
"""Initialize a ReasoningAgent that uses tree-of-thought reasoning.,
Args:
name: Name of the agent
llm_config: Configuration for the language model
max_depth (int): Maximum depth of the reasoning tree
beam_size (int): Number of parallel reasoning paths to maintain
answer_approach (str): Either "pool" or "best" - how to generate final answer
beam_size (int): DEPRECATED. Number of parallel reasoning paths to maintain
answer_approach (str): DEPRECATED. Either "pool" or "best" - how to generate final answer
verbose (bool): Whether to show intermediate steps
"""
super().__init__(name=name, llm_config=llm_config, **kwargs)
Expand All @@ -202,7 +299,19 @@ def __init__(
system_message="Rate the thinking trajectories for score 1 - 5 (1: worst, 5: best).",
llm_config=llm_config,
)
self.register_reply([Agent, None], ReasoningAgent.generate_response)

if reason_config:
method = reason_config.get("method", "beam_search")
if method == "beam_search":
self.register_reply([Agent, None], ReasoningAgent.generate_beam_response)
if "beam_size" in reason_config:
self.beam_size = reason_config["beam_size"]
if "answer_approach" in reason_config:
self.answer_approach = reason_config["answer_approach"]
elif method == "mcts":
self.register_reply([Agent, None], ReasoningAgent.generate_mcts_response)
self.mcts_simulations = reason_config.get("nsim", 10)
self.exploration_constant = reason_config.get("exploration_constant", 1.41)

self._root = None

Expand All @@ -216,7 +325,8 @@ def rate_node(self, node: ThinkNode) -> float:
float: Normalized score between 0 and 1 indicating trajectory quality
"""
self.send(
message=f"Rate the trajectory:\n{node.trajectory}", recipient=self.grader, request_reply=True, silent=False
message=f"Rate:\n{node.trajectory}", recipient=self.grader, request_reply=True,
silent=not self.verbose,
)
rating = self.grader.last_message()["content"].strip()
try:
Expand All @@ -226,7 +336,7 @@ def rate_node(self, node: ThinkNode) -> float:
reward = 0.0 # Default reward if parsing fails
return reward

def generate_response(self, messages, sender, config=None):
def generate_beam_response(self, messages, sender, config=None):
"""Generate a response using tree-of-thought reasoning.
Implements beam search through a tree of reasoning steps, using the thinker
Expand Down Expand Up @@ -257,29 +367,14 @@ def generate_response(self, messages, sender, config=None):
while prev_leafs and len(final_answers) < self.beam_size:
new_leafs = []
for node in prev_leafs:
if (self.max_depth and node.depth >= self.max_depth) or "TERMINATE" in node.content:
if self.is_terminal(node):
# Reached max depth; collect possible answers
if node.value is None:
node.value = self.rate_node(node)
final_answers.add(node)
continue

self.thinker.clear_history()
self.send(
message=f"{node.trajectory}\n---\nWhat are the possible next steps?",
recipient=self.thinker,
request_reply=True,
silent=False,
)
reply = self.thinker.last_message()["content"].strip()

options = re.findall(
r"Option \d+:(.+?)(?=Option \d+:|$)", reply, re.DOTALL
) # the options that the thinker provides
for option in options:
new_leafs.append(
ThinkNode(content=option.strip().rstrip(), parent=node)
) # each option is a new leaf node
new_leafs += self.expand(node)

prev_leafs = new_leafs

Expand Down Expand Up @@ -321,3 +416,113 @@ def generate_response(self, messages, sender, config=None):

final_answer = self.chat_messages[self][-1]["content"].strip()
return True, final_answer

def generate_mcts_response(self, messages, sender, config=None):
    """Reply function that answers the last message using Monte Carlo Tree Search.

    Runs ``self.mcts_simulations`` rounds. Each round performs selection (UCT
    over visited children), expansion/simulation (random descent through the
    options produced by ``self.expand``), answer generation, grading via
    ``self.rate_node``, and backpropagation of the reward to the root.

    If the prompt contains the marker ``GROUND_TRUTH``, everything from the
    marker onward is stripped from the prompt and given to the grader through
    its system message for the duration of this call only.

    Args:
        messages: Conversation history; when None, ``self._oai_messages[sender]`` is used.
        sender: The agent whose message is being answered.
        config: Unused.

    Returns:
        Tuple[bool, str]: ``(False, "")`` when the sender is this agent itself,
        ``(True, "TERMINATE")`` for an empty prompt, otherwise ``(True, answer)``
        where ``answer`` is the content of the highest-rated answer node.

    Raises:
        ValueError: If no simulation was run (``nsim`` < 1), so there is no answer to pick.
    """
    if sender == self:
        return False, ""  # Defer the LLM call to next reply functions.

    messages = self._oai_messages[sender] if messages is None else messages
    prompt = messages[-1]["content"].strip()
    if not prompt:
        return True, "TERMINATE"

    # Extract the ground truth for more accurate evaluation.
    # TODO: in the future, allow user to pass a callable (func) to calculate reward.
    if "GROUND_TRUTH" in prompt:
        idx = prompt.find("GROUND_TRUTH")
        prompt, ground_truth = prompt[:idx].rstrip(), prompt[idx:]
    else:
        ground_truth = None

    root = ThinkNode(content=prompt, parent=None)
    self._root = root
    answer_nodes = []

    # ids of non-terminal nodes for which the thinker produced no options.
    # They are treated like terminal nodes so that the search neither crashes
    # on them nor descends into the answer nodes attached below them.
    stalled = set()

    def _is_leaf(candidate):
        return self.is_terminal(candidate) or id(candidate) in stalled

    # The ground-truth grading prompt does not change between simulations, so
    # install it once and restore the grader's own prompt afterwards; otherwise
    # a stale ground truth would leak into later, unrelated requests.
    original_grader_prompt = self.grader.system_message
    if ground_truth:
        self.grader.update_system_message(
            f"Rate the answer for score 1 - 5 (1: worst, 5: best). The Ground Truth is:\n{ground_truth}"
        )

    try:
        # TODO: future, parallelism with Swarm agent or AsyncOpenAI client.
        for _ in range(self.mcts_simulations):
            node = root

            # Selection
            while not _is_leaf(node) and len(node.children) > 0:
                choices_weights = [
                    # exploitation term +
                    (child.value / (child.visits + EPSILON))
                    # exploration term
                    + self.exploration_constant
                    * math.sqrt(2 * math.log(node.visits + EPSILON) / (child.visits + EPSILON))
                    for child in node.children
                ]
                node = node.children[choices_weights.index(max(choices_weights))]

            # Expansion and Simulation
            while not _is_leaf(node):
                if len(node.children) == 0:
                    # ThinkNode registers itself with its parent, so a
                    # successful expansion populates node.children.
                    self.expand(node)
                    if len(node.children) == 0:
                        # No options could be parsed: answer from here rather
                        # than calling random.choice on an empty list.
                        stalled.add(id(node))
                        break
                node = random.choice(node.children)

            # Add answer (leaf) node and evaluate answer
            self.send(
                message=f"Answer the question {prompt}. Here is my thinking process:\n{node.trajectory}",
                recipient=self,
                request_reply=True,
                silent=not self.verbose,
            )
            _answer = self.last_message(self)["content"].strip()
            # We add the answer (as a node) to the leaf to help
            # future logging and debugging.
            _ans_node = ThinkNode(content=_answer, parent=node)

            reward = self.rate_node(_ans_node)
            _ans_node.value = reward
            answer_nodes.append(_ans_node)

            # Backpropagation
            while node is not None:
                node.visits += 1
                node.value = reward if node.value is None else node.value + reward
                node = node.parent
    finally:
        if ground_truth:
            self.grader.update_system_message(original_grader_prompt)

    if not answer_nodes:
        raise ValueError("MCTS produced no answers; `nsim` must be at least 1.")

    # Best action
    best_ans_node = max(answer_nodes, key=lambda candidate: candidate.value)
    return True, best_ans_node.content


def expand(self, node: ThinkNode) -> List:
    """
    Ask the thinker for next steps and attach each one as a child of ``node``.

    The thinker's history is cleared, the node's trajectory is sent to it with a
    request for possible next steps, and every ``Option N:`` section of its reply
    becomes a new ThinkNode whose parent is ``node``.

    Args:
        node (ThinkNode): The node to expand, representing the current state in the reasoning process.

    Returns:
        List[ThinkNode]: The newly created child nodes, in the order the options
        appear in the reply (empty if no option could be parsed).
    """
    thinker = self.thinker
    thinker.clear_history()

    request = f"{node.trajectory}\n---\nWhat are the possible next steps?"
    self.send(message=request, recipient=thinker, request_reply=True, silent=not self.verbose)
    reply = thinker.last_message()["content"].strip()

    # Capture the text after each "Option N:" up to (not including) the next
    # "Option N:" header or the end of the reply; DOTALL lets options span lines.
    option_pattern = re.compile(r"Option \d+:(.+?)(?=Option \d+:|$)", re.DOTALL)

    children = []
    for option_text in option_pattern.findall(reply):
        children.append(ThinkNode(content=option_text.strip(), parent=node))
    return children


def is_terminal(self, node):
    """Return True when ``node`` should not be expanded any further.

    A node is terminal when its content contains ``"TERMINATE"`` or when its
    depth has reached ``self.max_depth``.

    Args:
        node: The node to test; its ``depth`` and ``content`` are read.

    Returns:
        bool: Whether the node is terminal.
    """
    if "TERMINATE" in node.content:
        return True
    # A falsy max_depth (None or 0) means "no depth limit", matching the
    # `self.max_depth and node.depth >= self.max_depth` check this helper
    # replaced; without the guard, None raises TypeError on comparison.
    return bool(self.max_depth) and node.depth >= self.max_depth

4 changes: 2 additions & 2 deletions notebook/tree_of_thoughts.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit dfda100

Please sign in to comment.