diff --git a/autogen/agentchat/contrib/reasoning_agent.py b/autogen/agentchat/contrib/reasoning_agent.py index edd7427b70..34dd9be2bd 100644 --- a/autogen/agentchat/contrib/reasoning_agent.py +++ b/autogen/agentchat/contrib/reasoning_agent.py @@ -6,8 +6,8 @@ import re from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union -from ..agent import Agent -from ..assistant_agent import AssistantAgent +from autogen.agentchat.agent import Agent +from autogen.agentchat.assistant_agent import AssistantAgent EPSILON = 1e-6 @@ -30,7 +30,7 @@ **Format of Output:** -**Reflection** +REFLECTION: *Give a few sentence reflections on the previous steps, what are wrong and what are good.* **Possible Options:** @@ -161,6 +161,7 @@ def to_dict(self) -> Dict: "content": self.content, "value": self.value, "depth": self.depth, + "reflection": self.reflection, "visits": self.visits, "children": [child.to_dict() for child in self.children], } @@ -180,6 +181,7 @@ def from_dict(cls, data: Dict, parent: Optional["ThinkNode"] = None) -> "ThinkNo node.value = data["value"] node.depth = data["depth"] node.visits = data["visits"] + node.reflection = data.get("reflection", "") # Recursively create children for child_data in data["children"]: @@ -594,9 +596,9 @@ def expand(self, node: ThinkNode) -> List: silent=not self.verbose, ) reply = self.thinker.last_message()["content"].strip() - reflection = re.findall(r"Reflection:(.+?)Possible Options:", reply, re.DOTALL) + reflection = re.findall(r"REFLECTION:\s*(.+?)(?=\*\*Possible Options:\*\*|Option \d+:|$)", reply, re.DOTALL) if reflection: - node.reflection = reflection[0].strip().rstrip() + node.reflection += str(reflection[0].strip()) # Extract options from reply using regex: # - Matches text between "Option N:" and either next "Option N:" or end of string # - (?=...) is a lookahead to match option boundary without including it @@ -623,7 +625,7 @@ def generate_lats_response(self, messages, sender, config=None): # Helper function to determine if we should continue searching def should_continue(node, iteration): - if self._root.is_solved(): + if self._root.is_solved: return False if iteration >= self.lats_max_iterations: return False @@ -659,9 +661,9 @@ def should_continue(node, iteration): candidates = re.findall( r"Option \d+:(.+?)(?=Option \d+:|$)", self.thinker.last_message()["content"].strip(), re.DOTALL ) - for candidate in candidates[: self.lats_num_candidates]: child = ThinkNode(content=candidate.strip(), parent=current) + self.expand(child) # Evaluate candidate and backpropagate reward = self.rate_node(child, ground_truth) child.backpropagate(reward)