Skip to content

Commit

Permalink
update: added reflection and a few fixes
Browse files (browse the repository at this point in the history)
  • Loading branch information
Hk669 committed Dec 14, 2024
1 parent afc97b4 commit 3d94c74
Show file tree
Hide file tree
Showing 2 changed files with 940 additions and 939 deletions.
24 changes: 12 additions & 12 deletions autogen/agentchat/contrib/reasoning_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(self, content: str, parent: Optional["ThinkNode"] = None) -> None:
self.content = content
self.value = 0
self.parent = parent
self.reflection = ""
self.depth = self.parent.depth + 1 if parent else 0
self.children = []
self.visits = 0
Expand Down Expand Up @@ -305,6 +306,7 @@ def traverse_tree(node):
preference_pairs.append(
{
"instruction": node.trajectory,
"reflection": node.reflection,
"preferred_response": f"Step {child_a.depth}: {child_a.content}",
"dispreferred_response": f"Step {child_b.depth}: {child_b.content}",
}
Expand Down Expand Up @@ -357,6 +359,7 @@ def __init__(

if reason_config:
method = reason_config.get("method", "beam_search")
self.exploration_constant = reason_config.get("exploration_constant", 1.41)
if method == "beam_search":
self.register_reply([Agent, None], ReasoningAgent.generate_beam_response)
if "beam_size" in reason_config:
Expand All @@ -366,12 +369,12 @@ def __init__(
elif method == "mcts":
self.register_reply([Agent, None], ReasoningAgent.generate_mcts_response)
self.mcts_simulations = reason_config.get("nsim", 10)
self.exploration_constant = reason_config.get("exploration_constant", 1.41)
elif method == "lats":
self.register_reply([Agent, None], ReasoningAgent.generate_lats_response)
self.lats_max_iterations = reason_config.get("max_iterations", 5)
self.lats_num_candidates = reason_config.get("num_candidates", 3)

else:
raise ValueError("Reasoning method not specified in `reason_config`.")
self._root = None

def rate_node(self, node: ThinkNode, ground_truth: str = None) -> float:
Expand Down Expand Up @@ -400,7 +403,7 @@ def rate_node(self, node: ThinkNode, ground_truth: str = None) -> float:
rating = self.grader.last_message()["content"].strip()
try:
# Scale rating to [0, 1]
reward = (float(re.findall(r"[\d.]+", rating)[0]) - 1) / 4.0
reward = (float(re.findall(r"[\d.]+", rating)[0]) - 1) / 9.0
except (IndexError, ValueError):
reward = 0.0 # Default reward if parsing fails
return reward
Expand Down Expand Up @@ -562,13 +565,7 @@ def generate_mcts_response(self, messages, sender, config=None):
answer_nodes.append(_ans_node)

# Backpropagation
while node is not None:
node.visits += 1
if node.value is None:
node.value = reward
else:
node.value += reward
node = node.parent
node.backpropagate(reward)

# Best action
best_ans_node = max(answer_nodes, key=lambda node: node.value)
Expand Down Expand Up @@ -597,7 +594,9 @@ def expand(self, node: ThinkNode) -> List:
silent=not self.verbose,
)
reply = self.thinker.last_message()["content"].strip()

reflection = re.findall(r"Reflection:(.+?)Possible Options:", reply, re.DOTALL)
if reflection:
node.reflection = reflection[0].strip().rstrip()
# Extract options from reply using regex:
# - Matches text between "Option N:" and either next "Option N:" or end of string
# - (?=...) is a lookahead to match option boundary without including it
Expand Down Expand Up @@ -641,7 +640,8 @@ def should_continue(node, iteration):
# Use UCT formula similar to MCTS
choices_weights = [
(child.value / (child.visits + EPSILON))
+ 1.41 * math.sqrt(math.log(current.visits + EPSILON) / (child.visits + EPSILON))
+ self.exploration_constant
* math.sqrt(math.log(current.visits + EPSILON) / (child.visits + EPSILON))
for child in current.children
]
current = current.children[choices_weights.index(max(choices_weights))]
Expand Down
Loading

0 comments on commit 3d94c74

Please sign in to comment.