From 29f76d8de3c1a31323b65530ac3c587df2840cc8 Mon Sep 17 00:00:00 2001
From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com>
Date: Thu, 31 Oct 2024 11:03:45 -0400
Subject: [PATCH 1/3] Improve instructions generation for success tool

---
 src/controlflow/tasks/task.py | 54 ++++++++++++++++++++++++++---------
 1 file changed, 40 insertions(+), 14 deletions(-)

diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py
index 86914d95..52fa6dcf 100644
--- a/src/controlflow/tasks/task.py
+++ b/src/controlflow/tasks/task.py
@@ -582,9 +582,7 @@ def get_success_tool(self) -> Tool:
         Create an agent-compatible tool for marking this task as successful.
         """
         options = {}
-        instructions = unwrap("""
-            Use this tool to mark the task as successful and provide a result.
-        """)
+        instructions = None
         result_schema = None
 
         # if the result_type is a tuple of options, then we want the LLM to provide
@@ -605,12 +603,14 @@ def get_success_tool(self) -> Tool:
             options_str = "\n\n".join(
                 f"Option {i}: {option}" for i, option in serialized_options.items()
             )
-            instructions += "\n\n" + unwrap("""
-                Provide a single integer as the result, corresponding to the index
+            instructions = unwrap(
+                """
+                Provide a single integer as the task result, corresponding to the index
                 of your chosen option. Your options are:
 
                 {options_str}
-            """).format(options_str=options_str)
+                """
+            ).format(options_str=options_str)
 
         # otherwise try to load the schema for the result type
         elif self.result_type is not None:
@@ -628,6 +628,11 @@ def get_success_tool(self) -> Tool:
 
         # for basemodel subclasses, we accept the model properties directly as kwargs
         if safe_issubclass(result_schema, BaseModel):
+            instructions = unwrap(
+                f"""
+                Use this tool to mark the task as successful and provide a result. The result schema is: {result_schema}
+                """
+            )
 
             def succeed(**kwargs) -> str:
                 self.mark_successful(result=result_schema(**kwargs))
@@ -642,7 +647,13 @@ def succeed(**kwargs) -> str:
             )
 
         # for all other results, we create a single `result` kwarg to capture the result
-        else:
+        elif result_schema is not None:
+            instructions = unwrap(
+                f"""
+                Use this tool to mark the task as successful and provide a result.
+                The result schema is: {{"task_result": {result_schema}}}
+                """
+            )
 
             @tool(
                 name=f"mark_task_{self.id}_successful",
@@ -650,21 +661,34 @@ def succeed(**kwargs) -> str:
                 instructions=instructions,
                 include_return_description=False,
             )
-            def succeed(result: result_schema) -> str:  # type: ignore
+            def succeed(task_result: result_schema) -> str:  # type: ignore
                 if self.is_successful():
                     raise ValueError(
                         f"{self.friendly_name()} is already marked successful."
                     )
                 if options:
-                    if result not in options:
+                    if task_result not in options:
                         raise ValueError(
                             f"Invalid option. Please choose one of {options}"
                         )
-                    result = options[result]
-                self.mark_successful(result=result)
+                    task_result = options[task_result]
+                self.mark_successful(result=task_result)
+                return f"{self.friendly_name()} marked successful."
+
+        # for no result schema, we provide a tool that takes no arguments
+        else:
+
+            @tool(
+                name=f"mark_task_{self.id}_successful",
+                description=f"Mark task {self.id} as successful.",
+                instructions=instructions,
+                include_return_description=False,
+            )
+            def succeed() -> str:
+                self.mark_successful()
                 return f"{self.friendly_name()} marked successful."
 
-        return succeed
+            return succeed
 
     def get_fail_tool(self) -> Tool:
         """
@@ -673,8 +697,10 @@ def get_fail_tool(self) -> Tool:
 
         @tool(
             name=f"mark_task_{self.id}_failed",
-            description=(
-                f"Mark task {self.id} as failed. Only use when technical errors prevent success. Provide a detailed reason for the failure."
+            description=unwrap(
+                f"""Mark task {self.id} as failed. Only use when technical
+                errors prevent success. Provide a detailed reason for the
+                failure."""
             ),
             include_return_description=False,
         )

From 64d41ae6c2101429585216115848ed2649d96118 Mon Sep 17 00:00:00 2001
From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com>
Date: Thu, 31 Oct 2024 11:18:13 -0400
Subject: [PATCH 2/3] Ensure all instructions are collected

---
 src/controlflow/tasks/task.py | 49 ++++++++++++++++++++---------------
 1 file changed, 28 insertions(+), 21 deletions(-)

diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py
index 52fa6dcf..469388b2 100644
--- a/src/controlflow/tasks/task.py
+++ b/src/controlflow/tasks/task.py
@@ -582,7 +582,7 @@ def get_success_tool(self) -> Tool:
         Create an agent-compatible tool for marking this task as successful.
         """
         options = {}
-        instructions = None
+        instructions = []
         result_schema = None
 
         # if the result_type is a tuple of options, then we want the LLM to provide
@@ -603,14 +603,16 @@ def get_success_tool(self) -> Tool:
             options_str = "\n\n".join(
                 f"Option {i}: {option}" for i, option in serialized_options.items()
             )
-            instructions = unwrap(
-                """
-                Provide a single integer as the task result, corresponding to the index
-                of your chosen option. Your options are:
-
-                {options_str}
-                """
-            ).format(options_str=options_str)
+            instructions.append(
+                unwrap(
+                    """
+                    Provide a single integer as the task result, corresponding to the index
+                    of your chosen option. Your options are:
+
+                    {options_str}
+                    """
+                ).format(options_str=options_str)
+            )
 
         # otherwise try to load the schema for the result type
         elif self.result_type is not None:
@@ -628,10 +630,12 @@ def get_success_tool(self) -> Tool:
 
         # for basemodel subclasses, we accept the model properties directly as kwargs
         if safe_issubclass(result_schema, BaseModel):
-            instructions = unwrap(
-                f"""
-                Use this tool to mark the task as successful and provide a result. The result schema is: {result_schema}
-                """
+            instructions.append(
+                unwrap(
+                    f"""
+                    Use this tool to mark the task as successful and provide a result. The result schema is: {result_schema}
+                    """
+                )
             )
 
             def succeed(**kwargs) -> str:
@@ -642,23 +646,25 @@ def succeed(**kwargs) -> str:
                 fn=succeed,
                 name=f"mark_task_{self.id}_successful",
                 description=f"Mark task {self.id} as successful.",
-                instructions=instructions,
+                instructions="\n\n".join(instructions) or None,
                 parameters=result_schema.model_json_schema(),
             )
 
         # for all other results, we create a single `result` kwarg to capture the result
         elif result_schema is not None:
-            instructions = unwrap(
-                f"""
-                Use this tool to mark the task as successful and provide a result.
-                The result schema is: {{"task_result": {result_schema}}}
-                """
+            instructions.append(
+                unwrap(
+                    f"""
+                    Use this tool to mark the task as successful and provide a result with the `task_result` kwarg.
+                    The `task_result` schema is: {{"task_result": {result_schema}}}
+                    """
+                )
             )
 
             @tool(
                 name=f"mark_task_{self.id}_successful",
                 description=f"Mark task {self.id} as successful.",
-                instructions=instructions,
+                instructions="\n\n".join(instructions) or None,
                 include_return_description=False,
             )
             def succeed(task_result: result_schema) -> str:  # type: ignore
@@ -675,13 +681,14 @@ def succeed(task_result: result_schema) -> str:  # type: ignore
                 self.mark_successful(result=task_result)
                 return f"{self.friendly_name()} marked successful."
 
+            return succeed
         # for no result schema, we provide a tool that takes no arguments
         else:
 
             @tool(
                 name=f"mark_task_{self.id}_successful",
                 description=f"Mark task {self.id} as successful.",
-                instructions=instructions,
+                instructions="\n\n".join(instructions) or None,
                 include_return_description=False,
             )
             def succeed() -> str:

From 279c2c1f01125eca2fb38979642432cb7746470b Mon Sep 17 00:00:00 2001
From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com>
Date: Thu, 31 Oct 2024 11:30:34 -0400
Subject: [PATCH 3/3] Fix kwargs

---
 tests/tasks/test_tasks.py       | 12 ++++++------
 tests/utilities/test_testing.py |  4 ++--
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/tests/tasks/test_tasks.py b/tests/tasks/test_tasks.py
index fa6b2fb7..52480232 100644
--- a/tests/tasks/test_tasks.py
+++ b/tests/tasks/test_tasks.py
@@ -485,14 +485,14 @@ class TestSuccessTool:
     def test_success_tool(self):
         task = Task("choose 5", result_type=int)
         tool = task.get_success_tool()
-        tool.run(input=dict(result=5))
+        tool.run(input=dict(task_result=5))
         assert task.is_successful()
         assert task.result == 5
 
     def test_success_tool_with_list_of_options(self):
         task = Task('choose "good"', result_type=["bad", "good", "medium"])
         tool = task.get_success_tool()
-        tool.run(input=dict(result=1))
+        tool.run(input=dict(task_result=1))
         assert task.is_successful()
         assert task.result == "good"
 
@@ -500,12 +500,12 @@ def test_success_tool_with_list_of_options_requires_int(self):
         task = Task('choose "good"', result_type=["bad", "good", "medium"])
         tool = task.get_success_tool()
         with pytest.raises(ValueError):
-            tool.run(input=dict(result="good"))
+            tool.run(input=dict(task_result="good"))
 
     def test_tuple_of_ints_result(self):
         task = Task("choose 5", result_type=(4, 5, 6))
         tool = task.get_success_tool()
-        tool.run(input=dict(result=1))
+        tool.run(input=dict(task_result=1))
         assert task.result == 5
 
     def test_tuple_of_pydantic_models_result(self):
@@ -518,7 +518,7 @@ class Person(BaseModel):
             result_type=(Person(name="Alice", age=30), Person(name="Bob", age=35)),
         )
         tool = task.get_success_tool()
-        tool.run(input=dict(result=1))
+        tool.run(input=dict(task_result=1))
         assert task.result == Person(name="Bob", age=35)
         assert isinstance(task.result, Person)
 
@@ -604,7 +604,7 @@ def test_invalid_completion_tool(self):
     def test_manual_success_tool(self):
         task = Task(objective="Test task", completion_tools=[], result_type=int)
         success_tool = task.get_success_tool()
-        success_tool.run(input=dict(result=5))
+        success_tool.run(input=dict(task_result=5))
         assert task.is_successful()
         assert task.result == 5
 
diff --git a/tests/utilities/test_testing.py b/tests/utilities/test_testing.py
index 0d411da2..74ed216b 100644
--- a/tests/utilities/test_testing.py
+++ b/tests/utilities/test_testing.py
@@ -20,7 +20,7 @@ def test_record_task_events(default_fake_llm):
         tool_calls=[
            {
                 "name": "mark_task_12345_successful",
-                "args": {"result": "Hello!"},
+                "args": {"task_result": "Hello!"},
                 "id": "call_ZEPdV8mCgeBe5UHjKzm6e3pe",
                 "type": "tool_call",
             }
@@ -39,7 +39,7 @@ def test_record_task_events(default_fake_llm):
     assert events[3].event == "tool-result"
     assert events[3].tool_call == {
         "name": "mark_task_12345_successful",
-        "args": {"result": "Hello!"},
+        "args": {"task_result": "Hello!"},
         "id": "call_ZEPdV8mCgeBe5UHjKzm6e3pe",
         "type": "tool_call",
     }
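
Taken together, these patches rename the success tool's `result` kwarg to `task_result` and collect the generated tool instructions in a list. A minimal usage sketch of the new calling convention, mirroring the updated tests (the top-level `from controlflow import Task` import is an assumption; the patches themselves only touch `src/controlflow/tasks/task.py` and the tests):

    from controlflow import Task  # assumed import path for the Task class patched above

    task = Task("choose 5", result_type=int)
    tool = task.get_success_tool()
    tool.run(input=dict(task_result=5))  # before these patches: dict(result=5)
    assert task.is_successful() and task.result == 5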