Skip to content

Commit

Permalink
Deprecation warning and notebook update
Browse files Browse the repository at this point in the history
  • Loading branch information
BabyCNM committed Dec 18, 2024
1 parent 5690a7e commit d25359c
Show file tree
Hide file tree
Showing 6 changed files with 5,062 additions and 5,558 deletions.
59 changes: 45 additions & 14 deletions autogen/agentchat/contrib/reasoning_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import math
import random
import re
import warnings
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

from ..agent import Agent
Expand Down Expand Up @@ -310,23 +311,50 @@ def __init__(
Args:
name: Name of the agent
llm_config: Configuration for the language model
grader_llm_config: Optional separate configuration for the grader model. If not provided, uses llm_config
max_depth (int): Maximum depth of the reasoning tree
beam_size (int): DEPRECATED. Number of parallel reasoning paths to maintain
answer_approach (str): DEPRECATED. Either "pool" or "best" - how to generate final answer
verbose (bool): Whether to show intermediate steps
reason_config (dict): Configuration for the reasoning method, e.g.,
{"method": "mcts"} or
{"method": "beam_search", "beam_size": 3, "answer_approach": "pool"} or
{"method": "lats", "max_iterations": 10, "num_candidates": 5}
reason_config (dict): Configuration for the reasoning method. Supported parameters:
method (str): The search strategy to use. Options:
- "beam_search" (default): Uses beam search with parallel paths
- "mcts": Uses Monte Carlo Tree Search for exploration
- "lats": Uses Language Agent Tree Search with per-step rewards
- "dfs": Uses depth-first search (equivalent to beam_search with beam_size=1)
Common parameters:
max_depth (int): Maximum depth of reasoning tree (default: 3)
forest_size (int): Number of independent trees to maintain (default: 1)
rating_scale (int): Scale for grading responses, e.g. 1-10 (default: 10)
Beam Search specific:
beam_size (int): Number of parallel paths to maintain (default: 3)
answer_approach (str): How to select final answer, "pool" or "best" (default: "pool")
MCTS/LATS specific:
nsim (int): Number of simulations to run (default: 3)
exploration_constant (float): UCT exploration parameter (default: 1.41)
Example configs:
{"method": "beam_search", "beam_size": 5, "max_depth": 4}
{"method": "mcts", "nsim": 10, "exploration_constant": 2.0}
{"method": "lats", "nsim": 5, "forest_size": 3}
"""
super().__init__(name=name, llm_config=llm_config, **kwargs)
self._max_depth = max_depth
self._beam_size = beam_size
self._verbose = verbose
self._answer_approach = answer_approach
self._llm_config = llm_config
self._grader_llm_config = grader_llm_config if grader_llm_config else llm_config

if max_depth != 4 or beam_size != 3 or answer_approach != "pool":
# deprecate warning
warnings.warn(
"The parameters max_depth, beam_size, and answer_approach have been deprecated. "
"Please use the reason_config dictionary to configure these settings instead.",
DeprecationWarning,
)

if reason_config is None:
reason_config = {}
self._reason_config = reason_config
Expand All @@ -336,13 +364,14 @@ def __init__(
if self._method == "dfs":
self._beam_size = 1
else:
self._beam_size = reason_config.get("beam_size", 3)
self._answer_approach = reason_config.get("answer_approach", "pool")
self._beam_size = reason_config.get("beam_size", beam_size)
self._answer_approach = reason_config.get("answer_approach", answer_approach)
assert self._answer_approach in ["pool", "best"]
elif self._method in ["mcts", "lats"]:
self._nsim = reason_config.get("nsim", 3)
self._exploration_constant = reason_config.get("exploration_constant", 1.41)

self._max_depth = reason_config.get("max_depth", max_depth)
self._forest_size = reason_config.get("forest_size", 1) # We default use only 1 tree.
self._rating_scale = reason_config.get("rating_scale", 10)

Expand Down Expand Up @@ -374,10 +403,12 @@ def generate_forest_response(self, messages, sender, config=None):

forest_answers = []
for _ in range(self._forest_size):
if self._method == "beam_search":
success, response = self._beam_reply(prompt, ground_truth)
if self._method in ["beam_search", "dfs"]:
response = self._beam_reply(prompt, ground_truth)
elif self._method in ["mcts", "lats"]:
success, response = self._mtcs_reply(prompt, ground_truth)
response = self._mtcs_reply(prompt, ground_truth)
else:
raise ValueError("Invalid reasoning method specified.")

forest_answers.append(response)

Expand Down Expand Up @@ -564,7 +595,7 @@ def _beam_reply(self, prompt, ground_truth=""):
)

final_answer = self.chat_messages[self][-1]["content"].strip()
return True, final_answer
return final_answer

def _mtcs_reply(self, prompt, ground_truth=""):
root = ThinkNode(content=prompt, parent=None)
Expand Down Expand Up @@ -622,7 +653,7 @@ def _mtcs_reply(self, prompt, ground_truth=""):

# Best action
best_ans_node = max(answer_nodes, key=lambda node: node.value)
return True, best_ans_node.content
return best_ans_node.content

def _expand(self, node: ThinkNode) -> List:
"""
Expand Down
Loading

0 comments on commit d25359c

Please sign in to comment.