diff --git a/scripts/tool_benchmarks.py b/scripts/tool_benchmarks.py
index 41e3523..0118e49 100644
--- a/scripts/tool_benchmarks.py
+++ b/scripts/tool_benchmarks.py
@@ -40,49 +40,40 @@
 client = Client()  # Launch langsmith client for cloning datasets


+def get_few_shot_messages(task_name):
+    if task_name == "Multiverse Math":
+        uncleaned_examples = [
+            e
+            for e in client.list_examples(
+                dataset_name="multiverse-math-examples-for-few-shot"
+            )
+        ]
+        few_shot_messages = []
+        examples = []
+        for i in range(len(uncleaned_examples)):
+            converted_messages = convert_to_messages(
+                uncleaned_examples[i].outputs["output"]
+            )
+            examples.append(
+                # The message at index 1 is the human message asking the actual math question (0th message is system prompt)
+                {
+                    "question": converted_messages[1].content,
+                    "messages": [
+                        m
+                        for m in converted_messages
+                        if not isinstance(m, SystemMessage)
+                    ],
+                }
+            )
+            few_shot_messages += converted_messages
+
+        return examples, [
+            m for m in few_shot_messages if not isinstance(m, SystemMessage)
+        ]
+    else:
+        raise ValueError("Few shot messages not supported for this dataset")
-experiment_uuid = uuid.uuid4().hex[:4]
-today = datetime.date.today().isoformat()
-for task in registry.tasks:
-    if task.type != "ToolUsageTask":
-        continue
-    # This is a small test dataset that can be used to verify
-    # that everything is set up correctly prior to running over
-    # all results. We may remove it in the future.
-    if task.name != "Multiverse Math":
-        continue
-
-    dataset_name = task.name
-
-    uncleaned_examples = [
-        e
-        for e in client.list_examples(
-            dataset_name="multiverse-math-examples-for-few-shot"
-        )
-    ]
-    few_shot_messages = []
-    examples = []
-    for i in range(len(uncleaned_examples)):
-        converted_messages = convert_to_messages(
-            uncleaned_examples[i].outputs["output"]
-        )
-        examples.append(
-            # The message at index 1 is the human message (0th message is system prompt)
-            {
-                "question": converted_messages[1].content,
-                "messages": [
-                    m
-                    for m in converted_messages
-                    if isinstance(m, SystemMessage) == False
-                ],
-            }
-        )
-        few_shot_messages += converted_messages
-
-    few_shot_messages = [
-        m for m in few_shot_messages if not isinstance(m, SystemMessage)
-    ]
-
+def get_few_shot_str_from_messages(few_shot_messages):
     few_shot_str = ""
     for m in few_shot_messages:
         if isinstance(m.content, list):
@@ -101,8 +92,10 @@
         few_shot_str += "\n"


+def get_prompts(task_name, **kwargs):
+    if task_name == "Multiverse Math":
         example_selector = SemanticSimilarityExampleSelector.from_examples(
-            examples,
+            kwargs["examples"],
             OpenAIEmbeddings(),
             FAISS,
             k=3,
@@ -115,8 +108,7 @@
             example_selector=example_selector,
             example_prompt=MessagesPlaceholder("messages"),
         )
-
-    prompts = [
+        return [
             (
                 hub.pull("multiverse-math-no-few-shot"),
                 "no-few-shot",
@@ -141,33 +133,46 @@
                         MessagesPlaceholder("agent_scratchpad"),
                     ]
                 ),
-            "few-shot-semantic",
+                "few-shot-semantic-openai-embeddings",
             ),
         ]
-    for model_name, model in tests[-2:-1]:
-        rate_limiter = RateLimiter(requests_per_second=1)
-
-        print(f"Benchmarking {task.name} with model: {model_name}")
-        eval_config = task.get_eval_config()
-
-        for prompt, prompt_name in prompts:
-            agent_factory = StandardAgentFactory(
-                task, model, prompt, rate_limiter=rate_limiter
-            )
-
-            client.run_on_dataset(
-                dataset_name=dataset_name,
-                llm_or_chain_factory=agent_factory,
-                evaluation=eval_config,
-                verbose=False,
-                project_name=f"{model_name}-{task.name}-{prompt_name}-{experiment_uuid}",
-                concurrency_level=5,
-                project_metadata={
-                    "model": model_name,
-                    "id": experiment_uuid,
-                    "task": task.name,
-                    "date": today,
-                    "langchain_benchmarks_version": __version__,
-                },
-            )
+experiment_uuid = uuid.uuid4().hex[:4]
+today = datetime.date.today().isoformat()
+for task in registry.tasks:
+    if task.type != "ToolUsageTask" or task.name != "Multiverse Math":
+        continue
+
+    dataset_name = task.name
+
+    examples, few_shot_messages = get_few_shot_messages(task.name)
+    few_shot_str = get_few_shot_str_from_messages(few_shot_messages)
+    prompts = get_prompts(task.name, examples=examples)
+
+    # Fireworks API limit reached, so only test a subset of the models for now
+    for model_name, model in tests[-2:-1]:
+        rate_limiter = RateLimiter(requests_per_second=1)
+
+        print(f"Benchmarking {task.name} with model: {model_name}")
+        eval_config = task.get_eval_config()
+
+        for prompt, prompt_name in prompts:
+            agent_factory = StandardAgentFactory(
+                task, model, prompt, rate_limiter=rate_limiter
+            )
+
+            client.run_on_dataset(
+                dataset_name=dataset_name,
+                llm_or_chain_factory=agent_factory,
+                evaluation=eval_config,
+                verbose=False,
+                project_name=f"{model_name}-{task.name}-{prompt_name}-{experiment_uuid}",
+                concurrency_level=5,
+                project_metadata={
+                    "model": model_name,
+                    "id": experiment_uuid,
+                    "task": task.name,
+                    "date": today,
+                    "langchain_benchmarks_version": __version__,
+                },
+            )
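
Not part of the patch: a minimal standalone sketch of the message conversion and system-message filtering that get_few_shot_messages performs. The raw_output list below is a hypothetical stand-in for one LangSmith example's outputs["output"] from the "multiverse-math-examples-for-few-shot" dataset; only langchain-core is assumed to be installed.

    # Hypothetical stand-in for one example's outputs["output"]; the real script
    # pulls these records from LangSmith via client.list_examples(...).
    from langchain_core.messages import SystemMessage, convert_to_messages

    raw_output = [
        {"role": "system", "content": "You are solving math in an alternate universe."},
        {"role": "user", "content": "What is 2 + 3 in the multiverse?"},
        {"role": "assistant", "content": "In the multiverse, 2 + 3 = 6."},
    ]

    converted = convert_to_messages(raw_output)

    # Index 1 is the human question (index 0 is the system prompt), mirroring how
    # the script builds each few-shot example before dropping system messages.
    example = {
        "question": converted[1].content,
        "messages": [m for m in converted if not isinstance(m, SystemMessage)],
    }

    print(example["question"])       # What is 2 + 3 in the multiverse?
    print(len(example["messages"]))  # 2 -- the system message was filtered out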