Skip to content

Commit

Permalink
Merge pull request #372 from PrefectHQ/success-tool-generation
Browse files Browse the repository at this point in the history
Improve instructions generation for success tool
  • Loading branch information
jlowin authored Oct 31, 2024
2 parents 973d663 + 279c2c1 commit ae59a63
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 26 deletions.
69 changes: 51 additions & 18 deletions src/controlflow/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,9 +582,7 @@ def get_success_tool(self) -> Tool:
Create an agent-compatible tool for marking this task as successful.
"""
options = {}
instructions = unwrap("""
Use this tool to mark the task as successful and provide a result.
""")
instructions = []
result_schema = None

# if the result_type is a tuple of options, then we want the LLM to provide
Expand All @@ -605,12 +603,16 @@ def get_success_tool(self) -> Tool:
options_str = "\n\n".join(
f"Option {i}: {option}" for i, option in serialized_options.items()
)
instructions += "\n\n" + unwrap("""
Provide a single integer as the result, corresponding to the index
of your chosen option. Your options are:
{options_str}
""").format(options_str=options_str)
instructions.append(
unwrap(
"""
Provide a single integer as the task result, corresponding to the index
of your chosen option. Your options are:
{options_str}
"""
).format(options_str=options_str)
)

# otherwise try to load the schema for the result type
elif self.result_type is not None:
Expand All @@ -628,6 +630,13 @@ def get_success_tool(self) -> Tool:

# for basemodel subclasses, we accept the model properties directly as kwargs
if safe_issubclass(result_schema, BaseModel):
instructions.append(
unwrap(
f"""
Use this tool to mark the task as successful and provide a result. The result schema is: {result_schema}
"""
)
)

def succeed(**kwargs) -> str:
self.mark_successful(result=result_schema(**kwargs))
Expand All @@ -637,34 +646,56 @@ def succeed(**kwargs) -> str:
fn=succeed,
name=f"mark_task_{self.id}_successful",
description=f"Mark task {self.id} as successful.",
instructions=instructions,
instructions="\n\n".join(instructions) or None,
parameters=result_schema.model_json_schema(),
)

# for all other results, we create a single `result` kwarg to capture the result
else:
elif result_schema is not None:
instructions.append(
unwrap(
f"""
Use this tool to mark the task as successful and provide a result with the `task_result` kwarg.
The `task_result` schema is: {{"task_result": {result_schema}}}
"""
)
)

@tool(
name=f"mark_task_{self.id}_successful",
description=f"Mark task {self.id} as successful.",
instructions=instructions,
instructions="\n\n".join(instructions) or None,
include_return_description=False,
)
def succeed(result: result_schema) -> str: # type: ignore
def succeed(task_result: result_schema) -> str: # type: ignore
if self.is_successful():
raise ValueError(
f"{self.friendly_name()} is already marked successful."
)
if options:
if result not in options:
if task_result not in options:
raise ValueError(
f"Invalid option. Please choose one of {options}"
)
result = options[result]
self.mark_successful(result=result)
task_result = options[task_result]
self.mark_successful(result=task_result)
return f"{self.friendly_name()} marked successful."

return succeed
# for no result schema, we provide a tool that takes no arguments
else:

@tool(
name=f"mark_task_{self.id}_successful",
description=f"Mark task {self.id} as successful.",
instructions="\n\n".join(instructions) or None,
include_return_description=False,
)
def succeed() -> str:
self.mark_successful()
return f"{self.friendly_name()} marked successful."

return succeed

def get_fail_tool(self) -> Tool:
"""
Expand All @@ -673,8 +704,10 @@ def get_fail_tool(self) -> Tool:

@tool(
name=f"mark_task_{self.id}_failed",
description=(
f"Mark task {self.id} as failed. Only use when technical errors prevent success. Provide a detailed reason for the failure."
description=unwrap(
f"""Mark task {self.id} as failed. Only use when technical
errors prevent success. Provide a detailed reason for the
failure."""
),
include_return_description=False,
)
Expand Down
12 changes: 6 additions & 6 deletions tests/tasks/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,27 +485,27 @@ class TestSuccessTool:
def test_success_tool(self):
task = Task("choose 5", result_type=int)
tool = task.get_success_tool()
tool.run(input=dict(result=5))
tool.run(input=dict(task_result=5))
assert task.is_successful()
assert task.result == 5

def test_success_tool_with_list_of_options(self):
task = Task('choose "good"', result_type=["bad", "good", "medium"])
tool = task.get_success_tool()
tool.run(input=dict(result=1))
tool.run(input=dict(task_result=1))
assert task.is_successful()
assert task.result == "good"

def test_success_tool_with_list_of_options_requires_int(self):
task = Task('choose "good"', result_type=["bad", "good", "medium"])
tool = task.get_success_tool()
with pytest.raises(ValueError):
tool.run(input=dict(result="good"))
tool.run(input=dict(task_result="good"))

def test_tuple_of_ints_result(self):
task = Task("choose 5", result_type=(4, 5, 6))
tool = task.get_success_tool()
tool.run(input=dict(result=1))
tool.run(input=dict(task_result=1))
assert task.result == 5

def test_tuple_of_pydantic_models_result(self):
Expand All @@ -518,7 +518,7 @@ class Person(BaseModel):
result_type=(Person(name="Alice", age=30), Person(name="Bob", age=35)),
)
tool = task.get_success_tool()
tool.run(input=dict(result=1))
tool.run(input=dict(task_result=1))
assert task.result == Person(name="Bob", age=35)
assert isinstance(task.result, Person)

Expand Down Expand Up @@ -604,7 +604,7 @@ def test_invalid_completion_tool(self):
def test_manual_success_tool(self):
task = Task(objective="Test task", completion_tools=[], result_type=int)
success_tool = task.get_success_tool()
success_tool.run(input=dict(result=5))
success_tool.run(input=dict(task_result=5))
assert task.is_successful()
assert task.result == 5

Expand Down
4 changes: 2 additions & 2 deletions tests/utilities/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_record_task_events(default_fake_llm):
tool_calls=[
{
"name": "mark_task_12345_successful",
"args": {"result": "Hello!"},
"args": {"task_result": "Hello!"},
"id": "call_ZEPdV8mCgeBe5UHjKzm6e3pe",
"type": "tool_call",
}
Expand All @@ -39,7 +39,7 @@ def test_record_task_events(default_fake_llm):
assert events[3].event == "tool-result"
assert events[3].tool_call == {
"name": "mark_task_12345_successful",
"args": {"result": "Hello!"},
"args": {"task_result": "Hello!"},
"id": "call_ZEPdV8mCgeBe5UHjKzm6e3pe",
"type": "tool_call",
}
Expand Down

0 comments on commit ae59a63

Please sign in to comment.