diff --git a/docs/source/notebooks/tool_usage.ipynb b/docs/source/notebooks/tool_usage.ipynb index 51091058..f4c8df75 100644 --- a/docs/source/notebooks/tool_usage.ipynb +++ b/docs/source/notebooks/tool_usage.ipynb @@ -43,28 +43,17 @@ "text/html": [ "\n", "\n", - "\n", + "\n", "\n", "\n", - "\n", - "\n", - "\n", - "\n", + "\n", + "\n", "\n", "
IDName Dataset ID Description
Name Type Dataset ID Description
0Tool Usage - Relational Data e95d45da-aaa3-44b3-ba2b-7c15ff6e46f5Environment with fake data about users and their locations and favorite foods.\n", + "
Tool Usage - Typewriter (1 func)ToolUsageTask placeholder Environment with a single function that accepts a single letter as input, and "prints" it on a piece of paper.\n", "\n", - "The environment provides a set of tools that can be used to query the data.\n", - "\n", - "The objective of this task is to evaluate the ability to use the provided tools to answer questions about relational data.\n", - "\n", - "The dataset contains 21 examples of varying difficulty. The difficulty is measured by the number of tools that need to be used to answer the question.\n", - "\n", - "Each example is composed of a question, a reference answer, and information about the sequence in which tools should be used to answer the question.\n", - "\n", - "Success is measured by the ability to answer the question correctly, and efficiently.
1Tool Usage - Typewriter (1 func)placeholder Environment with a single function that accepts a single letter as input, and "prints" it on a piece of paper.\n", - "\n", - "The objective of this task is to evaluate the ability to use the provided tools to repeat a given input string.\n", + "The objective of this task is to evaluate the ability to use the provided tools to repeat a given input string.\n", "\n", "For example, if the string is 'abc', the tools 'a', 'b', and 'c' must be invoked in that order.\n", "\n", "The dataset includes examples of varying difficulty. The difficulty is measured by the length of the string.
2Tool Usage - Typewriter placeholder Environment with 26 functions each representing a letter of the alphabet.\n", + "
Tool Usage - Typewriter ToolUsageTask placeholder Environment with 26 functions each representing a letter of the alphabet.\n", "\n", "In this variation of the typewriter task, there are 26 parameterless functions, where each function represents a letter of the alphabet (instead of a single function that takes a letter as an argument).\n", "\n", @@ -73,16 +62,32 @@ "For example, if the string is 'abc', the tools 'a', 'b', and 'c' must be invoked in that order.\n", "\n", "The dataset includes examples of varying difficulty. The difficulty is measured by the length of the string.
3Multiverse Math placeholder An environment that contains a few basic math operations, but with altered results.\n", + "
Tool Usage - Relational Data ToolUsageTask e95d45da-aaa3-44b3-ba2b-7c15ff6e46f5 Environment with fake data about users and their locations and favorite foods.\n", + "\n", + "The environment provides a set of tools that can be used to query the data.\n", + "\n", + "The objective of this task is to evaluate the ability to use the provided tools to answer questions about relational data.\n", + "\n", + "The dataset contains 21 examples of varying difficulty. The difficulty is measured by the number of tools that need to be used to answer the question.\n", + "\n", + "Each example is composed of a question, a reference answer, and information about the sequence in which tools should be used to answer the question.\n", + "\n", + "Success is measured by the ability to answer the question correctly, and efficiently.
Multiverse Math ToolUsageTask placeholder An environment that contains a few basic math operations, but with altered results.\n", "\n", "For example, multiplication of 5*3 will be re-interpreted as 5*3*1.1. The basic operations retain some basic properties, such as commutativity, associativity, and distributivity; however, the results are different than expected.\n", "\n", "The objective of this task is to evaluate the ability to use the provided tools to solve simple math questions and ignore any innate knowledge about math.
Email Extraction ExtractionTaskhttps://smith.langchain.com/public/36bdfe7d-3cd1-4b36-b957-d12d95810a2b/dA dataset of 42 real emails deduped from a spam folder, with semantic HTML tags removed, as well as a script for initial extraction and formatting of other emails from an arbitrary .mbox file like the one exported by Gmail.\n", + "\n", + "Some additional cleanup of the data was done by hand after the initial pass.\n", + "\n", + "See https://github.com/jacoblee93/oss-model-extraction-evals.
" ], "text/plain": [ - "Registry(tasks=[Task(id=0, name='Tool Usage - Relational Data', dataset_id='e95d45da-aaa3-44b3-ba2b-7c15ff6e46f5', create_environment=, description='Environment with fake data about users and their locations and favorite foods.\\n\\nThe environment provides a set of tools that can be used to query the data.\\n\\nThe objective of this task is to evaluate the ability to use the provided tools to answer questions about relational data.\\n\\nThe dataset contains 21 examples of varying difficulty. The difficulty is measured by the number of tools that need to be used to answer the question.\\n\\nEach example is composed of a question, a reference answer, and information about the sequence in which tools should be used to answer the question.\\n\\nSuccess is measured by the ability to answer the question correctly, and efficiently.\\n', instructions=\"Please answer the user's question by using the tools provided. Do not guess the answer. Keep in mind that entities like users,foods and locations have both a name and an ID, which are not the same.\"), Task(id=1, name='Tool Usage - Typewriter (1 func)', dataset_id='placeholder', create_environment=, description='Environment with a single function that accepts a single letter as input, and \"prints\" it on a piece of paper.\\n\\nThe objective of this task is to evaluate the ability to use the provided tools to repeat a given input string.\\n\\nFor example, if the string is \\'abc\\', the tools \\'a\\', \\'b\\', and \\'c\\' must be invoked in that order.\\n\\nThe dataset includes examples of varying difficulty. The difficulty is measured by the length of the string.\\n', instructions=\"Repeat the given string by using the provided tools. Do not write anything else or provide any explanations. For example, if the string is 'abc', you must invoke the tools 'a', 'b', and 'c' in that order. Please invoke the function with a single letter at a time.\"), Task(id=2, name='Tool Usage - Typewriter', dataset_id='placeholder', create_environment=, description=\"Environment with 26 functions each representing a letter of the alphabet.\\n\\nIn this variation of the typewriter task, there are 26 parameterless functions, where each function represents a letter of the alphabet (instead of a single function that takes a letter as an argument).\\n\\nThe object is to evaluate the ability of use the functions to repeat the given string.\\n\\nFor example, if the string is 'abc', the tools 'a', 'b', and 'c' must be invoked in that order.\\n\\nThe dataset includes examples of varying difficulty. The difficulty is measured by the length of the string.\\n\", instructions=\"Repeat the given string by using the provided tools. Do not write anything else or provide any explanations. For example, if the string is 'abc', you must invoke the tools 'a', 'b', and 'c' in that order. Please invoke the functions without any arguments.\"), Task(id=3, name='Multiverse Math', dataset_id='placeholder', create_environment=, description='An environment that contains a few basic math operations, but with altered results.\\n\\nFor example, multiplication of 5*3 will be re-interpreted as 5*3*1.1. The basic operations retain some basic properties, such as commutativity, associativity, and distributivity; however, the results are different than expected.\\n\\nThe objective of this task is to evaluate the ability to use the provided tools to solve simple math questions and ignore any innate knowledge about math.\\n', instructions='You are requested to solve math questions in an alternate mathematical universe. The rules of association, commutativity, and distributivity still apply, but the operations have been altered to yield different results than expected. Solve the given math questions using the provided tools. Do not guess the answer.')])" + "Registry(tasks=[ToolUsageTask(name='Tool Usage - Typewriter (1 func)', dataset_id='placeholder', description='Environment with a single function that accepts a single letter as input, and \"prints\" it on a piece of paper.\\n\\nThe objective of this task is to evaluate the ability to use the provided tools to repeat a given input string.\\n\\nFor example, if the string is \\'abc\\', the tools \\'a\\', \\'b\\', and \\'c\\' must be invoked in that order.\\n\\nThe dataset includes examples of varying difficulty. The difficulty is measured by the length of the string.\\n', create_environment=, instructions=\"Repeat the given string by using the provided tools. Do not write anything else or provide any explanations. For example, if the string is 'abc', you must invoke the tools 'a', 'b', and 'c' in that order. Please invoke the function with a single letter at a time.\"), ToolUsageTask(name='Tool Usage - Typewriter', dataset_id='placeholder', description=\"Environment with 26 functions each representing a letter of the alphabet.\\n\\nIn this variation of the typewriter task, there are 26 parameterless functions, where each function represents a letter of the alphabet (instead of a single function that takes a letter as an argument).\\n\\nThe object is to evaluate the ability of use the functions to repeat the given string.\\n\\nFor example, if the string is 'abc', the tools 'a', 'b', and 'c' must be invoked in that order.\\n\\nThe dataset includes examples of varying difficulty. The difficulty is measured by the length of the string.\\n\", create_environment=, instructions=\"Repeat the given string by using the provided tools. Do not write anything else or provide any explanations. For example, if the string is 'abc', you must invoke the tools 'a', 'b', and 'c' in that order. Please invoke the functions without any arguments.\"), ToolUsageTask(name='Tool Usage - Relational Data', dataset_id='e95d45da-aaa3-44b3-ba2b-7c15ff6e46f5', description='Environment with fake data about users and their locations and favorite foods.\\n\\nThe environment provides a set of tools that can be used to query the data.\\n\\nThe objective of this task is to evaluate the ability to use the provided tools to answer questions about relational data.\\n\\nThe dataset contains 21 examples of varying difficulty. The difficulty is measured by the number of tools that need to be used to answer the question.\\n\\nEach example is composed of a question, a reference answer, and information about the sequence in which tools should be used to answer the question.\\n\\nSuccess is measured by the ability to answer the question correctly, and efficiently.\\n', create_environment=, instructions=\"Please answer the user's question by using the tools provided. Do not guess the answer. Keep in mind that entities like users,foods and locations have both a name and an ID, which are not the same.\"), ToolUsageTask(name='Multiverse Math', dataset_id='placeholder', description='An environment that contains a few basic math operations, but with altered results.\\n\\nFor example, multiplication of 5*3 will be re-interpreted as 5*3*1.1. The basic operations retain some basic properties, such as commutativity, associativity, and distributivity; however, the results are different than expected.\\n\\nThe objective of this task is to evaluate the ability to use the provided tools to solve simple math questions and ignore any innate knowledge about math.\\n', create_environment=, instructions='You are requested to solve math questions in an alternate mathematical universe. The rules of association, commutativity, and distributivity still apply, but the operations have been altered to yield different results than expected. Solve the given math questions using the provided tools. Do not guess the answer.'), ExtractionTask(name='Email Extraction', dataset_id='https://smith.langchain.com/public/36bdfe7d-3cd1-4b36-b957-d12d95810a2b/d', description='A dataset of 42 real emails deduped from a spam folder, with semantic HTML tags removed, as well as a script for initial extraction and formatting of other emails from an arbitrary .mbox file like the one exported by Gmail.\\n\\nSome additional cleanup of the data was done by hand after the initial pass.\\n\\nSee https://github.com/jacoblee93/oss-model-extraction-evals.\\n ', model=)])" ] }, "execution_count": 2, @@ -107,8 +112,8 @@ "text/html": [ "\n", "\n", - "\n", "\n", + "\n", "\n", "
ID 0
Name Tool Usage - Relational Data
Type ToolUsageTask
Dataset ID e95d45da-aaa3-44b3-ba2b-7c15ff6e46f5
DescriptionEnvironment with fake data about users and their locations and favorite foods.\n", "\n", @@ -117,7 +122,7 @@ "
" ], "text/plain": [ - "Task(id=0, name='Tool Usage - Relational Data', dataset_id='e95d45da-aaa3-44b3-ba2b-7c15ff6e46f5', create_environment=, description='Environment with fake data about users and their locations and favorite foods.\\n\\nThe environment provides a set of tools that can be used to query the data.\\n\\nThe objective of this task is to evaluate the ability to use the provided tools to answer questions about relational data.\\n\\nThe dataset contains 21 examples of varying difficulty. The difficulty is measured by the number of tools that need to be used to answer the question.\\n\\nEach example is composed of a question, a reference answer, and information about the sequence in which tools should be used to answer the question.\\n\\nSuccess is measured by the ability to answer the question correctly, and efficiently.\\n', instructions=\"Please answer the user's question by using the tools provided. Do not guess the answer. Keep in mind that entities like users,foods and locations have both a name and an ID, which are not the same.\")" + "ToolUsageTask(name='Tool Usage - Relational Data', dataset_id='e95d45da-aaa3-44b3-ba2b-7c15ff6e46f5', description='Environment with fake data about users and their locations and favorite foods.\\n\\nThe environment provides a set of tools that can be used to query the data.\\n\\nThe objective of this task is to evaluate the ability to use the provided tools to answer questions about relational data.\\n\\nThe dataset contains 21 examples of varying difficulty. The difficulty is measured by the number of tools that need to be used to answer the question.\\n\\nEach example is composed of a question, a reference answer, and information about the sequence in which tools should be used to answer the question.\\n\\nSuccess is measured by the ability to answer the question correctly, and efficiently.\\n', create_environment=, instructions=\"Please answer the user's question by using the tools provided. Do not guess the answer. Keep in mind that entities like users,foods and locations have both a name and an ID, which are not the same.\")" ] }, "execution_count": 3, @@ -177,27 +182,12 @@ "tags": [] }, "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "908dfb7a73ea4332a77336ba00ed1ba4", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/21 [00:00 AgentExecutor:\n", - " \"\"\"Agent Executor\"\"\"\n", - " llm = ChatOpenAI(\n", - " model=\"gpt-3.5-turbo-16k\",\n", - " temperature=0,\n", - " )\n", - "\n", - " env = task.create_environment()\n", - "\n", - " llm_with_tools = llm.bind(\n", - " functions=[format_tool_to_openai_function(t) for t in env.tools]\n", - " )\n", - " prompt = ChatPromptTemplate.from_messages(\n", - " [\n", - " (\n", - " \"system\",\n", - " \"You are a helpful assistant. Use the given tools to answer the question. Keep in mind that an ID is distinct from a name for every entity.\",\n", - " ),\n", - " MessagesPlaceholder(variable_name=\"agent_scratchpad\"),\n", - " (\"user\", \"{input}\"),\n", - " ]\n", - " )\n", - "\n", - " runnable_agent = (\n", - " {\n", - " \"input\": lambda x: x[\"question\"],\n", - " \"agent_scratchpad\": lambda x: format_to_openai_functions(\n", - " x[\"intermediate_steps\"]\n", - " ),\n", - " }\n", - " | prompt\n", - " | llm_with_tools\n", - " | OpenAIFunctionsAgentOutputParser()\n", - " )\n", - "\n", - " def _ensure_output_exists(inputs):\n", - " \"\"\"Make sure that the output key is always present.\"\"\"\n", - " if \"output\" not in inputs:\n", - " return {\"output\": \"\", **inputs}\n", - " return inputs\n", - "\n", - " return (\n", - " AgentExecutor(\n", - " agent=runnable_agent,\n", - " tools=env.tools,\n", - " handle_parsing_errors=True,\n", - " return_intermediate_steps=True,\n", - " )\n", - " | _ensure_output_exists\n", - " )" + "from langchain_benchmarks.tool_usage import agents" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "0ae8c6be-899c-44a6-a89b-0fc04c2cb05c", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "agent_factory = agents.OpenAIAgentFactory(task, model=\"gpt-3.5-turbo-16k\")" ] }, { @@ -296,7 +239,19 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 11, + "id": "612fb603-1401-426b-8a19-4453ad5b698a", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "agent = agent_factory.create()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, "id": "0e4896fa-3633-44a1-857f-80a263cf2e03", "metadata": { "tags": [] @@ -306,7 +261,7 @@ "data": { "text/plain": [ "{'question': 'who is bob?',\n", - " 'output': 'Bob is a user with the ID 21.',\n", + " 'output': 'Bob is a user with the name \"Bob\".',\n", " 'intermediate_steps': [(AgentActionMessageLog(tool='find_users_by_name', tool_input={'name': 'bob'}, log=\"\\nInvoking: `find_users_by_name` with `{'name': 'bob'}`\\n\\n\\n\", message_log=[AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{\\n \"name\": \"bob\"\\n}', 'name': 'find_users_by_name'}})]),\n", " [{'id': 21, 'name': 'Bob'},\n", " {'id': 41, 'name': 'Donna'},\n", @@ -318,13 +273,13 @@ " 'Bob')]}" ] }, - "execution_count": 7, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "agent_factory().invoke({\"question\": \"who is bob?\"})" + "agent.invoke({\"question\": \"who is bob?\"})" ] }, { @@ -339,7 +294,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 13, "id": "513042fe-2878-44f8-ae84-05b9d521c1de", "metadata": { "tags": [] @@ -352,7 +307,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 14, "id": "2bedd9d1-fc06-4066-9f89-b874ae818d82", "metadata": { "tags": [] @@ -364,21 +319,915 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "id": "aab7514e-a6ef-4c21-b90f-d9cbefcf5af1", "metadata": { "tags": [] }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "View the evaluation results for project 'test-puzzled-fold-42' at:\n", + "https://smith.langchain.com/o/e081f11e-fbd2-41b4-9fa8-5d76c76ef854/projects/p/3d206d3f-aad1-4226-86a6-4161857e5bca?eval=true\n", + "\n", + "View all tests for Dataset Tool Usage - Relational Data at:\n", + "https://smith.langchain.com/o/e081f11e-fbd2-41b4-9fa8-5d76c76ef854/datasets/f2b5a831-8eef-4bc7-b6de-68078b87350f\n", + "[------------------------------------------------->] 21/21\n", + " Eval quantiles:\n", + " 0.25 0.5 0.75 mean mode\n", + "Intermediate steps correctness 0.0 1.0 1.0 0.571429 1.0\n", + "# steps / # expected steps 1.0 1.0 1.0 2.285714 1.0\n", + "correctness 0.0 1.0 1.0 0.666667 1.0\n" + ] + } + ], "source": [ "test_run = client.run_on_dataset(\n", " dataset_name=task.name,\n", - " llm_or_chain_factory=agent_factory,\n", + " llm_or_chain_factory=agent_factory.create,\n", " evaluation=STANDARD_AGENT_EVALUATOR,\n", " verbose=True,\n", " tags=[\"openai-functions\"],\n", ")" ] + }, + { + "cell_type": "markdown", + "id": "1b039225-01cf-481a-87a6-4e880e9b1dcd", + "metadata": {}, + "source": [ + "# Inspect\n", + "\n", + "Here, we'll take a look at the underlying results a little bit.\n", + "\n", + "A few things to note:\n", + "\n", + "* The correctness is 66% (so it's messing up a lot!)\n", + "* The number of tool invocations made by the agent can be very large; e.g., 15 invocations, when only a single invocation was actually needed." + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "id": "6eb19db1-43b8-4866-a3d2-f211ba92ab8b", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "df = test_run.to_dataframe()\n", + "df = pd.json_normalize(df.to_dict(orient=\"records\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 83, + "id": "7ab5a8b9-a937-4537-b879-704284df4494", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "0.6666666666666666" + ] + }, + "execution_count": 83, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df[\"correctness\"].mean()" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "id": "ab7516ed-36b1-4c16-bf4a-cc49077460ad", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "df[\"num_expected_steps\"] = df[\"reference.expected_steps\"].apply(len)\n", + "df[\"actual_number_of_steps\"] = df[\"output.intermediate_steps\"].apply(len)" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "id": "50d7590d-20de-4768-ac90-adcdbfa70068", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Intermediate steps correctness# steps / # expected stepscorrectnessinput.questionoutput.questionoutput.outputoutput.intermediate_stepsreference.referencereference.order_mattersreference.expected_stepsnum_expected_stepsactual_number_of_steps
011.01What is the city for location ID 1?What is the city for location ID 1?The city for location ID 1 is New York.[(tool='get_city_for_location' tool_input={'lo...New YorkTrue[get_city_for_location]11
111.01What is the name of food with id 6?What is the name of food with id 6?The name of the food with ID 6 is Pasta.[(tool='get_food_name' tool_input={'food_id': ...PastaTrue[get_food_name]11
211.01what is eve's user id?what is eve's user id?Eve's user ID is 42.[(tool='find_users_by_name' tool_input={'name'...42True[find_users_by_name]11
3015.00get the current user idget the current user idAgent stopped due to iteration limit or time l...[(tool='get_current_user_id' tool_input={} log...35True[get_current_user_id]115
411.00How many users by the name of bob?How many users by the name of bob?There are multiple users with the name \"Bob\".[(tool='find_users_by_name' tool_input={'name'...1True[find_users_by_name]11
\n", + "
" + ], + "text/plain": [ + " Intermediate steps correctness # steps / # expected steps correctness \\\n", + "0 1 1.0 1 \n", + "1 1 1.0 1 \n", + "2 1 1.0 1 \n", + "3 0 15.0 0 \n", + "4 1 1.0 0 \n", + "\n", + " input.question output.question \\\n", + "0 What is the city for location ID 1? What is the city for location ID 1? \n", + "1 What is the name of food with id 6? What is the name of food with id 6? \n", + "2 what is eve's user id? what is eve's user id? \n", + "3 get the current user id get the current user id \n", + "4 How many users by the name of bob? How many users by the name of bob? \n", + "\n", + " output.output \\\n", + "0 The city for location ID 1 is New York. \n", + "1 The name of the food with ID 6 is Pasta. \n", + "2 Eve's user ID is 42. \n", + "3 Agent stopped due to iteration limit or time l... \n", + "4 There are multiple users with the name \"Bob\". \n", + "\n", + " output.intermediate_steps reference.reference \\\n", + "0 [(tool='get_city_for_location' tool_input={'lo... New York \n", + "1 [(tool='get_food_name' tool_input={'food_id': ... Pasta \n", + "2 [(tool='find_users_by_name' tool_input={'name'... 42 \n", + "3 [(tool='get_current_user_id' tool_input={} log... 35 \n", + "4 [(tool='find_users_by_name' tool_input={'name'... 1 \n", + "\n", + " reference.order_matters reference.expected_steps num_expected_steps \\\n", + "0 True [get_city_for_location] 1 \n", + "1 True [get_food_name] 1 \n", + "2 True [find_users_by_name] 1 \n", + "3 True [get_current_user_id] 1 \n", + "4 True [find_users_by_name] 1 \n", + "\n", + " actual_number_of_steps \n", + "0 1 \n", + "1 1 \n", + "2 1 \n", + "3 15 \n", + "4 1 " + ] + }, + "execution_count": 68, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 85, + "id": "ffab97b7-eda2-408d-b611-596b637e627a", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "df = df.sort_values(\"actual_number_of_steps\", ascending=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 86, + "id": "20eb92f0-9373-4741-a851-b21c41f8c203", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Intermediate steps correctness# steps / # expected stepscorrectnessinput.questionoutput.questionoutput.outputoutput.intermediate_stepsreference.referencereference.order_mattersreference.expected_stepsnum_expected_stepsactual_number_of_steps
3015.00get the current user idget the current user idAgent stopped due to iteration limit or time l...[(tool='get_current_user_id' tool_input={} log...35True[get_current_user_id]115
707.50weather in LA right now?weather in LA right now?Agent stopped due to iteration limit or time l...[(tool='find_locations_by_name' tool_input={'c...Sunny, Temperature: 75°FTrue[find_locations_by_name, get_current_weather_f...215
807.50time in chicagotime in chicagoAgent stopped due to iteration limit or time l...[(tool='find_locations_by_name' tool_input={'c...2023-11-14 11:15 AMTrue[find_locations_by_name, get_current_time_for_...215
1502.01whats the name of the city where bob lives?whats the name of the city where bob lives?The name of the city where Bob lives is Los An...[(tool='get_user_location' tool_input={'user_i...Los AngelesTrue[find_users_by_name, get_user_location, get_ci...36
2001.01do bob and alice live in the same city?do bob and alice live in the same city?No, Bob and Alice do not live in the same city...[(tool='find_users_by_name' tool_input={'name'...noFalse[find_users_by_name, get_user_location, get_ci...55
1302.00Frank who is Even's friend is allergic to dair...Frank who is Even's friend is allergic to dair...Frank is not allergic to dairy, so he can eat ...[(tool='find_users_by_name' tool_input={'name'...yesTrue[find_users_by_name, get_food_allergic_ingredi...24
1811.01do alice and charlie use the same email provider?do alice and charlie use the same email provider?No, Alice uses the email provider \"gmail.com\" ...[(tool='find_users_by_name' tool_input={'name'...noTrue[find_users_by_name, get_user_email, get_user_...33
1601.01Donna is about to go outside. Does she need an...Donna is about to go outside. Does she need an...Yes, Donna needs an umbrella because it is cur...[(tool='find_users_by_name' tool_input={'name'...yesTrue[find_users_by_name, get_user_location, get_cu...33
1411.01what is the current users favorite color and n...what is the current users favorite color and n...The current user's favorite color is yellow an...[(tool='get_current_user_id' tool_input={} log...yellow and CharlieTrue[get_current_user_id, get_user_favorite_color,...33
1111.01what is the current users favorite color?what is the current users favorite color?The current user's favorite color is yellow.[(tool='get_current_user_id' tool_input={} log...yellowTrue[get_current_user_id, get_user_favorite_color]22
1211.01eve ate a serving of sushi, what allergens was...eve ate a serving of sushi, what allergens was...If Eve ate a serving of sushi, she would have ...[(tool='find_foods_by_name' tool_input={'food'...fish, soyTrue[find_foods_by_name, get_food_allergic_ingredi...22
1011.01If i eat a serving of pizza, how many calories...If i eat a serving of pizza, how many calories...If you eat a serving of pizza, you will consum...[(tool='find_foods_by_name' tool_input={'food'...285 caloriesTrue[find_foods_by_name, get_food_calories]22
911.01list the allergens in chocolatelist the allergens in chocolateThe allergens in chocolate are milk and soy.[(tool='find_foods_by_name' tool_input={'food'...milk, soyTrue[find_foods_by_name, get_food_allergic_ingredi...22
611.01find donna's favorite colorfind donna's favorite colorDonna's favorite color is green.[(tool='find_users_by_name' tool_input={'name'...greenTrue[find_users_by_name, get_user_favorite_color]22
511.01what is alice's email address?what is alice's email address?Alice's email address is alice@gmail.com.[(tool='find_users_by_name' tool_input={'name'...alice@gmail.comTrue[find_users_by_name, get_user_email]22
111.01What is the name of food with id 6?What is the name of food with id 6?The name of the food with ID 6 is Pasta.[(tool='get_food_name' tool_input={'food_id': ...PastaTrue[get_food_name]11
411.00How many users by the name of bob?How many users by the name of bob?There are multiple users with the name \"Bob\".[(tool='find_users_by_name' tool_input={'name'...1True[find_users_by_name]11
211.01what is eve's user id?what is eve's user id?Eve's user ID is 42.[(tool='find_users_by_name' tool_input={'name'...42True[find_users_by_name]11
011.01What is the city for location ID 1?What is the city for location ID 1?The city for location ID 1 is New York.[(tool='get_city_for_location' tool_input={'lo...New YorkTrue[get_city_for_location]11
1700.00Is it likely that Donna is awake right now?Is it likely that Donna is awake right now?I'm sorry, but I don't have access to informat...[]yesTrue[find_users_by_name, get_user_location, get_cu...30
1900.00Is it likely that Donna is outside with an umb...Is it likely that Donna is outside with an umb...I'm sorry, but I don't have access to real-tim...[]yesFalse[find_users_by_name, get_user_location, get_cu...40
\n", + "
" + ], + "text/plain": [ + " Intermediate steps correctness # steps / # expected steps correctness \\\n", + "3 0 15.0 0 \n", + "7 0 7.5 0 \n", + "8 0 7.5 0 \n", + "15 0 2.0 1 \n", + "20 0 1.0 1 \n", + "13 0 2.0 0 \n", + "18 1 1.0 1 \n", + "16 0 1.0 1 \n", + "14 1 1.0 1 \n", + "11 1 1.0 1 \n", + "12 1 1.0 1 \n", + "10 1 1.0 1 \n", + "9 1 1.0 1 \n", + "6 1 1.0 1 \n", + "5 1 1.0 1 \n", + "1 1 1.0 1 \n", + "4 1 1.0 0 \n", + "2 1 1.0 1 \n", + "0 1 1.0 1 \n", + "17 0 0.0 0 \n", + "19 0 0.0 0 \n", + "\n", + " input.question \\\n", + "3 get the current user id \n", + "7 weather in LA right now? \n", + "8 time in chicago \n", + "15 whats the name of the city where bob lives? \n", + "20 do bob and alice live in the same city? \n", + "13 Frank who is Even's friend is allergic to dair... \n", + "18 do alice and charlie use the same email provider? \n", + "16 Donna is about to go outside. Does she need an... \n", + "14 what is the current users favorite color and n... \n", + "11 what is the current users favorite color? \n", + "12 eve ate a serving of sushi, what allergens was... \n", + "10 If i eat a serving of pizza, how many calories... \n", + "9 list the allergens in chocolate \n", + "6 find donna's favorite color \n", + "5 what is alice's email address? \n", + "1 What is the name of food with id 6? \n", + "4 How many users by the name of bob? \n", + "2 what is eve's user id? \n", + "0 What is the city for location ID 1? \n", + "17 Is it likely that Donna is awake right now? \n", + "19 Is it likely that Donna is outside with an umb... \n", + "\n", + " output.question \\\n", + "3 get the current user id \n", + "7 weather in LA right now? \n", + "8 time in chicago \n", + "15 whats the name of the city where bob lives? \n", + "20 do bob and alice live in the same city? \n", + "13 Frank who is Even's friend is allergic to dair... \n", + "18 do alice and charlie use the same email provider? \n", + "16 Donna is about to go outside. Does she need an... \n", + "14 what is the current users favorite color and n... \n", + "11 what is the current users favorite color? \n", + "12 eve ate a serving of sushi, what allergens was... \n", + "10 If i eat a serving of pizza, how many calories... \n", + "9 list the allergens in chocolate \n", + "6 find donna's favorite color \n", + "5 what is alice's email address? \n", + "1 What is the name of food with id 6? \n", + "4 How many users by the name of bob? \n", + "2 what is eve's user id? \n", + "0 What is the city for location ID 1? \n", + "17 Is it likely that Donna is awake right now? \n", + "19 Is it likely that Donna is outside with an umb... \n", + "\n", + " output.output \\\n", + "3 Agent stopped due to iteration limit or time l... \n", + "7 Agent stopped due to iteration limit or time l... \n", + "8 Agent stopped due to iteration limit or time l... \n", + "15 The name of the city where Bob lives is Los An... \n", + "20 No, Bob and Alice do not live in the same city... \n", + "13 Frank is not allergic to dairy, so he can eat ... \n", + "18 No, Alice uses the email provider \"gmail.com\" ... \n", + "16 Yes, Donna needs an umbrella because it is cur... \n", + "14 The current user's favorite color is yellow an... \n", + "11 The current user's favorite color is yellow. \n", + "12 If Eve ate a serving of sushi, she would have ... \n", + "10 If you eat a serving of pizza, you will consum... \n", + "9 The allergens in chocolate are milk and soy. \n", + "6 Donna's favorite color is green. \n", + "5 Alice's email address is alice@gmail.com. \n", + "1 The name of the food with ID 6 is Pasta. \n", + "4 There are multiple users with the name \"Bob\". \n", + "2 Eve's user ID is 42. \n", + "0 The city for location ID 1 is New York. \n", + "17 I'm sorry, but I don't have access to informat... \n", + "19 I'm sorry, but I don't have access to real-tim... \n", + "\n", + " output.intermediate_steps \\\n", + "3 [(tool='get_current_user_id' tool_input={} log... \n", + "7 [(tool='find_locations_by_name' tool_input={'c... \n", + "8 [(tool='find_locations_by_name' tool_input={'c... \n", + "15 [(tool='get_user_location' tool_input={'user_i... \n", + "20 [(tool='find_users_by_name' tool_input={'name'... \n", + "13 [(tool='find_users_by_name' tool_input={'name'... \n", + "18 [(tool='find_users_by_name' tool_input={'name'... \n", + "16 [(tool='find_users_by_name' tool_input={'name'... \n", + "14 [(tool='get_current_user_id' tool_input={} log... \n", + "11 [(tool='get_current_user_id' tool_input={} log... \n", + "12 [(tool='find_foods_by_name' tool_input={'food'... \n", + "10 [(tool='find_foods_by_name' tool_input={'food'... \n", + "9 [(tool='find_foods_by_name' tool_input={'food'... \n", + "6 [(tool='find_users_by_name' tool_input={'name'... \n", + "5 [(tool='find_users_by_name' tool_input={'name'... \n", + "1 [(tool='get_food_name' tool_input={'food_id': ... \n", + "4 [(tool='find_users_by_name' tool_input={'name'... \n", + "2 [(tool='find_users_by_name' tool_input={'name'... \n", + "0 [(tool='get_city_for_location' tool_input={'lo... \n", + "17 [] \n", + "19 [] \n", + "\n", + " reference.reference reference.order_matters \\\n", + "3 35 True \n", + "7 Sunny, Temperature: 75°F True \n", + "8 2023-11-14 11:15 AM True \n", + "15 Los Angeles True \n", + "20 no False \n", + "13 yes True \n", + "18 no True \n", + "16 yes True \n", + "14 yellow and Charlie True \n", + "11 yellow True \n", + "12 fish, soy True \n", + "10 285 calories True \n", + "9 milk, soy True \n", + "6 green True \n", + "5 alice@gmail.com True \n", + "1 Pasta True \n", + "4 1 True \n", + "2 42 True \n", + "0 New York True \n", + "17 yes True \n", + "19 yes False \n", + "\n", + " reference.expected_steps num_expected_steps \\\n", + "3 [get_current_user_id] 1 \n", + "7 [find_locations_by_name, get_current_weather_f... 2 \n", + "8 [find_locations_by_name, get_current_time_for_... 2 \n", + "15 [find_users_by_name, get_user_location, get_ci... 3 \n", + "20 [find_users_by_name, get_user_location, get_ci... 5 \n", + "13 [find_users_by_name, get_food_allergic_ingredi... 2 \n", + "18 [find_users_by_name, get_user_email, get_user_... 3 \n", + "16 [find_users_by_name, get_user_location, get_cu... 3 \n", + "14 [get_current_user_id, get_user_favorite_color,... 3 \n", + "11 [get_current_user_id, get_user_favorite_color] 2 \n", + "12 [find_foods_by_name, get_food_allergic_ingredi... 2 \n", + "10 [find_foods_by_name, get_food_calories] 2 \n", + "9 [find_foods_by_name, get_food_allergic_ingredi... 2 \n", + "6 [find_users_by_name, get_user_favorite_color] 2 \n", + "5 [find_users_by_name, get_user_email] 2 \n", + "1 [get_food_name] 1 \n", + "4 [find_users_by_name] 1 \n", + "2 [find_users_by_name] 1 \n", + "0 [get_city_for_location] 1 \n", + "17 [find_users_by_name, get_user_location, get_cu... 3 \n", + "19 [find_users_by_name, get_user_location, get_cu... 4 \n", + "\n", + " actual_number_of_steps \n", + "3 15 \n", + "7 15 \n", + "8 15 \n", + "15 6 \n", + "20 5 \n", + "13 4 \n", + "18 3 \n", + "16 3 \n", + "14 3 \n", + "11 2 \n", + "12 2 \n", + "10 2 \n", + "9 2 \n", + "6 2 \n", + "5 2 \n", + "1 1 \n", + "4 1 \n", + "2 1 \n", + "0 1 \n", + "17 0 \n", + "19 0 " + ] + }, + "execution_count": 86, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df" + ] + }, + { + "cell_type": "markdown", + "id": "416dce43-7e76-431c-b556-55abef32f393", + "metadata": {}, + "source": [ + "An example of a poorly behaving agent that seems to have gotten stuck in a loop!" + ] + }, + { + "cell_type": "code", + "execution_count": 90, + "id": "5519c0ac-e241-4833-89ff-870259248bed", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[(AgentActionMessageLog(tool='find_locations_by_name', tool_input={'city': 'Chicago'}, log=\"\\nInvoking: `find_locations_by_name` with `{'city': 'Chicago'}`\\n\\n\\n\", message_log=[AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{\\n \"city\": \"Chicago\"\\n}', 'name': 'find_locations_by_name'}})]),\n", + " [{'id': 3, 'city': 'Chicago'},\n", + " {'id': 5, 'city': 'Miami'},\n", + " {'id': 2, 'city': 'Los Angeles'},\n", + " {'id': 4, 'city': 'Houston'},\n", + " {'id': 1, 'city': 'New York'}]),\n", + " (AgentActionMessageLog(tool='get_current_time_for_location', tool_input={'location_id': 3}, log=\"\\nInvoking: `get_current_time_for_location` with `{'location_id': 3}`\\n\\n\\n\", message_log=[AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{\\n \"location_id\": 3\\n}', 'name': 'get_current_time_for_location'}})]),\n", + " '2023-11-14 11:15 AM'),\n", + " (AgentActionMessageLog(tool='get_current_time_for_location', tool_input={'location_id': 3}, log=\"\\nInvoking: `get_current_time_for_location` with `{'location_id': 3}`\\n\\n\\n\", message_log=[AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{\\n \"location_id\": 3\\n}', 'name': 'get_current_time_for_location'}})]),\n", + " '2023-11-14 11:15 AM'),\n", + " (AgentActionMessageLog(tool='get_current_time_for_location', tool_input={'location_id': 3}, log=\"\\nInvoking: `get_current_time_for_location` with `{'location_id': 3}`\\n\\n\\n\", message_log=[AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{\\n \"location_id\": 3\\n}', 'name': 'get_current_time_for_location'}})]),\n", + " '2023-11-14 11:15 AM'),\n", + " (AgentActionMessageLog(tool='get_current_time_for_location', tool_input={'location_id': 3}, log=\"\\nInvoking: `get_current_time_for_location` with `{'location_id': 3}`\\n\\n\\n\", message_log=[AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{\\n \"location_id\": 3\\n}', 'name': 'get_current_time_for_location'}})]),\n", + " '2023-11-14 11:15 AM'),\n", + " (AgentActionMessageLog(tool='get_current_time_for_location', tool_input={'location_id': 3}, log=\"\\nInvoking: `get_current_time_for_location` with `{'location_id': 3}`\\n\\n\\n\", message_log=[AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{\\n \"location_id\": 3\\n}', 'name': 'get_current_time_for_location'}})]),\n", + " '2023-11-14 11:15 AM'),\n", + " (AgentActionMessageLog(tool='get_current_time_for_location', tool_input={'location_id': 3}, log=\"\\nInvoking: `get_current_time_for_location` with `{'location_id': 3}`\\n\\n\\n\", message_log=[AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{\\n \"location_id\": 3\\n}', 'name': 'get_current_time_for_location'}})]),\n", + " '2023-11-14 11:15 AM'),\n", + " (AgentActionMessageLog(tool='get_current_time_for_location', tool_input={'location_id': 3}, log=\"\\nInvoking: `get_current_time_for_location` with `{'location_id': 3}`\\n\\n\\n\", message_log=[AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{\\n \"location_id\": 3\\n}', 'name': 'get_current_time_for_location'}})]),\n", + " '2023-11-14 11:15 AM'),\n", + " (AgentActionMessageLog(tool='get_current_time_for_location', tool_input={'location_id': 3}, log=\"\\nInvoking: `get_current_time_for_location` with `{'location_id': 3}`\\n\\n\\n\", message_log=[AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{\\n \"location_id\": 3\\n}', 'name': 'get_current_time_for_location'}})]),\n", + " '2023-11-14 11:15 AM'),\n", + " (AgentActionMessageLog(tool='get_current_time_for_location', tool_input={'location_id': 3}, log=\"\\nInvoking: `get_current_time_for_location` with `{'location_id': 3}`\\n\\n\\n\", message_log=[AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{\\n \"location_id\": 3\\n}', 'name': 'get_current_time_for_location'}})]),\n", + " '2023-11-14 11:15 AM'),\n", + " (AgentActionMessageLog(tool='get_current_time_for_location', tool_input={'location_id': 3}, log=\"\\nInvoking: `get_current_time_for_location` with `{'location_id': 3}`\\n\\n\\n\", message_log=[AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{\\n \"location_id\": 3\\n}', 'name': 'get_current_time_for_location'}})]),\n", + " '2023-11-14 11:15 AM'),\n", + " (AgentActionMessageLog(tool='get_current_time_for_location', tool_input={'location_id': 3}, log=\"\\nInvoking: `get_current_time_for_location` with `{'location_id': 3}`\\n\\n\\n\", message_log=[AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{\\n \"location_id\": 3\\n}', 'name': 'get_current_time_for_location'}})]),\n", + " '2023-11-14 11:15 AM'),\n", + " (AgentActionMessageLog(tool='get_current_time_for_location', tool_input={'location_id': 3}, log=\"\\nInvoking: `get_current_time_for_location` with `{'location_id': 3}`\\n\\n\\n\", message_log=[AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{\\n \"location_id\": 3\\n}', 'name': 'get_current_time_for_location'}})]),\n", + " '2023-11-14 11:15 AM'),\n", + " (AgentActionMessageLog(tool='get_current_time_for_location', tool_input={'location_id': 3}, log=\"\\nInvoking: `get_current_time_for_location` with `{'location_id': 3}`\\n\\n\\n\", message_log=[AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{\\n \"location_id\": 3\\n}', 'name': 'get_current_time_for_location'}})]),\n", + " '2023-11-14 11:15 AM'),\n", + " (AgentActionMessageLog(tool='get_current_time_for_location', tool_input={'location_id': 3}, log=\"\\nInvoking: `get_current_time_for_location` with `{'location_id': 3}`\\n\\n\\n\", message_log=[AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{\\n \"location_id\": 3\\n}', 'name': 'get_current_time_for_location'}})]),\n", + " '2023-11-14 11:15 AM')]" + ] + }, + "execution_count": 90, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df[\"output.intermediate_steps\"].loc[8]" + ] } ], "metadata": { diff --git a/langchain_benchmarks/registration.py b/langchain_benchmarks/registration.py index ff3dcc6e..7ebf8fed 100644 --- a/langchain_benchmarks/registration.py +++ b/langchain_benchmarks/registration.py @@ -2,10 +2,10 @@ from langchain_benchmarks.extraction import email_task from langchain_benchmarks.schema import Registry -from langchain_benchmarks.tool_usage import ( - type_writer_26_funcs, +from langchain_benchmarks.tool_usage.tasks import ( type_writer, relational_data, + type_writer_26_funcs, multiverse_math, ) diff --git a/langchain_benchmarks/schema.py b/langchain_benchmarks/schema.py index 95f8245c..eac59c47 100644 --- a/langchain_benchmarks/schema.py +++ b/langchain_benchmarks/schema.py @@ -99,16 +99,18 @@ def _repr_html_(self) -> str: """Return an HTML representation of the registry.""" headers = [ "Name", + "Type", "Dataset ID", "Description", ] table = [ [ - env.name, - env.dataset_id, - env.description, + task.name, + task.__class__.__name__, + task.dataset_id, + task.description, ] - for env in self.tasks + for task in self.tasks ] return tabulate(table, headers=headers, tablefmt="html") diff --git a/langchain_benchmarks/tool_usage/agents.py b/langchain_benchmarks/tool_usage/agents.py new file mode 100644 index 00000000..5e77a2aa --- /dev/null +++ b/langchain_benchmarks/tool_usage/agents.py @@ -0,0 +1,76 @@ +"""Code for creating an agent factory for evaluating tool usage tasks.""" +from langchain.agents import AgentExecutor +from langchain.agents.format_scratchpad import format_to_openai_functions +from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser +from langchain.chat_models import ChatOpenAI +from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain.schema.runnable import Runnable +from langchain.tools.render import format_tool_to_openai_function + +from langchain_benchmarks.schema import ToolUsageTask + + +def _ensure_output_exists(inputs: dict) -> dict: + """Make sure that the output key is always present.""" + if "output" not in inputs: + return {"output": "", **inputs} + return inputs + + +class OpenAIAgentFactory: + def __init__( + self, task: ToolUsageTask, *, model: str = "gpt-3.5-turbo-16k" + ) -> None: + """Create an OpenAI agent factory for the given task. + + Args: + task: The task to create an agent factory for. + model: The model to use -- this must be an open AI model. + """ + self.task = task + self.model = model + + def create(self) -> Runnable: + """Agent Executor""" + llm = ChatOpenAI( + model=self.model, + temperature=0, + ) + + env = self.task.create_environment() + + llm_with_tools = llm.bind( + functions=[format_tool_to_openai_function(t) for t in env.tools] + ) + prompt = ChatPromptTemplate.from_messages( + [ + ( + "system", + self.task.instructions, + ), + MessagesPlaceholder(variable_name="agent_scratchpad"), + ("user", "{input}"), + ] + ) + + runnable_agent = ( + { + "input": lambda x: x["question"], + "agent_scratchpad": lambda x: format_to_openai_functions( + x["intermediate_steps"] + ), + } + | prompt + | llm_with_tools + | OpenAIFunctionsAgentOutputParser() + ) + + return ( + AgentExecutor( + agent=runnable_agent, + tools=env.tools, + handle_parsing_errors=True, + return_intermediate_steps=True, + ) + | _ensure_output_exists + ) diff --git a/langchain_benchmarks/tool_usage/tasks/__init__.py b/langchain_benchmarks/tool_usage/tasks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/langchain_benchmarks/tool_usage/multiverse_math.py b/langchain_benchmarks/tool_usage/tasks/multiverse_math.py similarity index 100% rename from langchain_benchmarks/tool_usage/multiverse_math.py rename to langchain_benchmarks/tool_usage/tasks/multiverse_math.py diff --git a/langchain_benchmarks/tool_usage/relational_data.py b/langchain_benchmarks/tool_usage/tasks/relational_data.py similarity index 100% rename from langchain_benchmarks/tool_usage/relational_data.py rename to langchain_benchmarks/tool_usage/tasks/relational_data.py diff --git a/langchain_benchmarks/tool_usage/type_writer.py b/langchain_benchmarks/tool_usage/tasks/type_writer.py similarity index 100% rename from langchain_benchmarks/tool_usage/type_writer.py rename to langchain_benchmarks/tool_usage/tasks/type_writer.py diff --git a/langchain_benchmarks/tool_usage/type_writer_26_funcs.py b/langchain_benchmarks/tool_usage/tasks/type_writer_26_funcs.py similarity index 100% rename from langchain_benchmarks/tool_usage/type_writer_26_funcs.py rename to langchain_benchmarks/tool_usage/tasks/type_writer_26_funcs.py diff --git a/tests/unit_tests/tool_usage/test_tool_usage.py b/tests/unit_tests/tool_usage/test_tool_usage.py index a3e6f4af..670ed6bf 100644 --- a/tests/unit_tests/tool_usage/test_tool_usage.py +++ b/tests/unit_tests/tool_usage/test_tool_usage.py @@ -1,3 +1,2 @@ def test_import_tool_usage() -> None: """Test that tool_usage can be imported""" - from langchain_benchmarks.tool_usage import evaluators # noqa: F401