Skip to content

Commit

Permalink
fix: reflection addition through expand
Browse files Browse the repository at this point in the history
  • Loading branch information
Hk669 committed Dec 14, 2024
1 parent 3d94c74 commit dd644bc
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions autogen/agentchat/contrib/reasoning_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import re
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

- from ..agent import Agent
- from ..assistant_agent import AssistantAgent
+ from autogen.agentchat.agent import Agent
+ from autogen.agentchat.assistant_agent import AssistantAgent

EPSILON = 1e-6

Expand All @@ -30,7 +30,7 @@
**Format of Output:**
- **Reflection**
+ REFLECTION:
*Give a few sentence reflections on the previous steps, what are wrong and what are good.*
**Possible Options:**
Expand Down Expand Up @@ -161,6 +161,7 @@ def to_dict(self) -> Dict:
"content": self.content,
"value": self.value,
"depth": self.depth,
"reflection": self.reflection,
"visits": self.visits,
"children": [child.to_dict() for child in self.children],
}
Expand All @@ -180,6 +181,7 @@ def from_dict(cls, data: Dict, parent: Optional["ThinkNode"] = None) -> "ThinkNo
node.value = data["value"]
node.depth = data["depth"]
node.visits = data["visits"]
node.reflection = data.get("reflection", "")

# Recursively create children
for child_data in data["children"]:
Expand Down Expand Up @@ -594,9 +596,9 @@ def expand(self, node: ThinkNode) -> List:
silent=not self.verbose,
)
reply = self.thinker.last_message()["content"].strip()
- reflection = re.findall(r"Reflection:(.+?)Possible Options:", reply, re.DOTALL)
+ reflection = re.findall(r"REFLECTION:\s*(.+?)(?=\*\*Possible Options:\*\*|Option \d+:|$)", reply, re.DOTALL)
if reflection:
- node.reflection = reflection[0].strip().rstrip()
+ node.reflection += str(reflection[0].strip())
# Extract options from reply using regex:
# - Matches text between "Option N:" and either next "Option N:" or end of string
# - (?=...) is a lookahead to match option boundary without including it
Expand All @@ -623,7 +625,7 @@ def generate_lats_response(self, messages, sender, config=None):

# Helper function to determine if we should continue searching
def should_continue(node, iteration):
- if self._root.is_solved():
+ if self._root.is_solved:
return False
if iteration >= self.lats_max_iterations:
return False
Expand Down Expand Up @@ -659,9 +661,9 @@ def should_continue(node, iteration):
candidates = re.findall(
r"Option \d+:(.+?)(?=Option \d+:|$)", self.thinker.last_message()["content"].strip(), re.DOTALL
)

for candidate in candidates[: self.lats_num_candidates]:
child = ThinkNode(content=candidate.strip(), parent=current)
self.expand(child)
# Evaluate candidate and backpropagate
reward = self.rate_node(child, ground_truth)
child.backpropagate(reward)
Expand Down

0 comments on commit dd644bc

Please sign in to comment.