Skip to content

Commit

Permalink
debug LATS
Browse files Browse the repository at this point in the history
  • Loading branch information
BabyCNM committed Dec 18, 2024
1 parent 6391356 commit 880ab78
Show file tree
Hide file tree
Showing 4 changed files with 1,366 additions and 875 deletions.
24 changes: 16 additions & 8 deletions autogen/agentchat/contrib/reasoning_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,8 +479,13 @@ def rate_node(self, node: ThinkNode, ground_truth: str = None, is_outcome: bool
message += f"--- Note that the Ground Truth is ---\n{ground_truth}\n---\n"
self._grader.update_system_message(message)

if self._method == "lats":
prompt = self._lats_context + "\n\n---\n\n" + f"Rate:\n{node.trajectory}"
else:
prompt = f"Rate:\n{node.trajectory}"

self.send(
message=f"Rate:\n{node.trajectory}",
message=prompt,
recipient=self._grader,
request_reply=True,
silent=not self._verbose,
Expand Down Expand Up @@ -602,6 +607,8 @@ def _mtcs_reply(self, prompt, ground_truth=""):
self._root = root
answer_nodes = []

self._lats_context = "## Here are some previous trajectories and reflections\n\n" # Store LATS's reflections

# TODO: future, parallelism with Swarm agent or AsyncOpenAI client.
for _ in range(self._nsim):
node = root
Expand All @@ -626,11 +633,6 @@ def _mtcs_reply(self, prompt, ground_truth=""):
while not self._is_terminal(node):
if len(node.children) == 0:
self._expand(node)
if self._method == "lats":
# In LATS: rate the quality of the current child node and
# backpropagate the reward to update the node's value and visits.
reward = self.rate_node(node, ground_truth)
node.backpropagate(reward)
node = random.choice(node.children)

# Add answer (leaf) node and evaluate answer
Expand All @@ -647,7 +649,7 @@ def _mtcs_reply(self, prompt, ground_truth=""):
reward = self.rate_node(_ans_node, ground_truth, is_outcome=True)
_ans_node.value = reward
answer_nodes.append(_ans_node)

self._lats_context += f"### Previous Tries:\n{node.trajectory}\n\nRating:{_ans_node.rating_details}\n\n"
# Backpropagation
node.backpropagate(reward)

Expand All @@ -671,8 +673,14 @@ def _expand(self, node: ThinkNode) -> List:
List[ThinkNode]: A list of new ThinkNode instances created from the options provided by the thinker.
"""
self._thinker.clear_history()

if self._method == "lats":
prompt = self._lats_context + "\n\n---\n\n" + f"{node.trajectory}\n---\nWhat are the possible next steps?"
else:
prompt = f"{node.trajectory}\n---\nWhat are the possible next steps?"

self.send(
message=f"{node.trajectory}\n---\nWhat are the possible next steps?",
message=prompt,
recipient=self._thinker,
request_reply=True,
silent=not self._verbose,
Expand Down
Loading

0 comments on commit 880ab78

Please sign in to comment.