Skip to content

Commit

Permalink
draft state
Browse files Browse the repository at this point in the history
  • Loading branch information
isahers1 committed Jul 22, 2024
1 parent e288225 commit a7879c1
Showing 1 changed file with 71 additions and 34 deletions.
105 changes: 71 additions & 34 deletions scripts/tool_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,40 +2,34 @@
import uuid

from langchain import hub
from langchain_anthropic import ChatAnthropic
from langchain_community.vectorstores import FAISS
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.messages.utils import convert_to_messages
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.prompts.few_shot import FewShotChatMessagePromptTemplate
from langchain_fireworks import ChatFireworks
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_openai import OpenAIEmbeddings
from langsmith.client import Client

from langchain_benchmarks import __version__, registry
from langchain_benchmarks.rate_limiting import RateLimiter
from langchain_benchmarks.tool_usage.agents import StandardAgentFactory
from langchain_benchmarks.tool_usage.tasks.multiverse_math import *
from langchain.chat_models import init_chat_model
from langsmith.evaluation import evaluate
from langchain.agents import AgentExecutor, create_tool_calling_agent

tests = [
(
"claude-3-haiku-20240307",
ChatAnthropic(model="claude-3-haiku-20240307", temperature=0),
),
(
"claude-3-sonnet-20240229",
ChatAnthropic(model="claude-3-sonnet-20240229", temperature=0),
),
("gpt-3.5-turbo-0125", ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)),
(
"gpt-4-turbo-2024-04-09",
ChatOpenAI(model="gpt-4-turbo-2024-04-09", temperature=0),
),
(
"accounts/fireworks/models/firefunction-v2",
ChatFireworks(model="accounts/fireworks/models/firefunction-v2", temperature=0),
),
("claude-3-haiku-20240307","anthropic",),
("claude-3-sonnet-20240229","anthropic",),
("claude-3-opus-20240229","anthropic",),
("claude-3-5-sonnet-20240620","anthropic",),
("gpt-3.5-turbo-0125", "openai"),
("gpt-4o","openai",),
("gpt-4o-mini","openai"),
("llama3-groq-70b-8192-tool-use-preview","groq"),
("llama3-groq-8b-8192-tool-use-preview","groq"),
("gemini-1.5-pro","google_vertexai"),
("gemini-1.5-flash","google_vertexai")
]

client = Client() # Launch langsmith client for cloning datasets
Expand All @@ -49,6 +43,7 @@ def get_few_shot_messages(task_name):
)
]
few_shot_messages = []
few_shot_three_messages = []
examples = []
for i in range(len(uncleaned_examples)):
converted_messages = convert_to_messages(
Expand All @@ -66,14 +61,16 @@ def get_few_shot_messages(task_name):
}
)
few_shot_messages += converted_messages
if i < 3:
few_shot_three_messages += converted_messages

return examples, [
m for m in few_shot_messages if not isinstance(m, SystemMessage)
]
], [m for m in few_shot_three_messages if not isinstance(m, SystemMessage)]
else:
raise ValueError("Few shot messages not supported for this dataset")

def get_few_shot_str_from_messages(few_shot_messages):
def turn_messages_to_str(few_shot_messages):
few_shot_str = ""
for m in few_shot_messages:
if isinstance(m.content, list):
Expand All @@ -91,6 +88,12 @@ def get_few_shot_str_from_messages(few_shot_messages):
few_shot_str += f"AI message: {m.content}"

few_shot_str += "\n"
return few_shot_str

def get_few_shot_str_from_messages(few_shot_messages, few_shot_three_messages):
few_shot_str = turn_messages_to_str(few_shot_messages)
few_shot_three_str = turn_messages_to_str(few_shot_three_messages)
return few_shot_str, few_shot_three_str

def get_prompts(task_name, **kwargs):
if task_name == "Multiverse Math":
Expand All @@ -110,17 +113,25 @@ def get_prompts(task_name, **kwargs):
)
return [
(
hub.pull("multiverse-math-no-few-shot"),
client.pull_prompt("langchain-ai/multiverse-math-no-few-shot"),
"no-few-shot",
),
(
hub.pull("multiverse-math-few-shot-messages"),
client.pull_prompt("langchain-ai/multiverse-math-few-shot-messages"),
"few-shot-messages",
),
(
hub.pull("multiverse-math-few-shot-str"),
client.pull_prompt("langchain-ai/multiverse-math-few-shot-str"),
"few-shot-string",
),
(
client.pull_prompt("langchain-ai/multiverse-math-few-shot-3-messages"),
"few-shot-three-messages",
),
(
client.pull_prompt("langchain-ai/multiverse-math-few-shot-3-str"),
"few-shot-three-strings",
),
(
ChatPromptTemplate.from_messages(
[
Expand All @@ -137,6 +148,11 @@ def get_prompts(task_name, **kwargs):
),
]

def predict_from_callable(callable,instructions):
def predict(run):
return callable.invoke({"question":run['question'],"instructions":instructions})['output']
return predict

experiment_uuid = uuid.uuid4().hex[:4]
today = datetime.date.today().isoformat()
for task in registry.tasks:
Expand All @@ -145,22 +161,26 @@ def get_prompts(task_name, **kwargs):

dataset_name = task.name

examples, few_shot_messages = get_few_shot_messages(task.name)
few_shot_str = get_few_shot_str_from_messages(few_shot_messages)
prompts = get_prompts(task.name,examples=examples)
examples, few_shot_messages, few_shot_three_messages = get_few_shot_messages(task.name)
few_shot_str, few_shot_three_str = get_few_shot_str_from_messages(few_shot_messages,few_shot_three_messages)
prompts = get_prompts(task.name,examples=examples,few_shot_three_messages=few_shot_three_messages,few_shot_three_str=few_shot_three_str)

# Fireworks API limit reached, so only test first 4 models for now
for model_name, model in tests[-2:-1]:
for model_name, model_provider in tests:
model = init_chat_model(model_name,model_provider=model_provider)
rate_limiter = RateLimiter(requests_per_second=1)

print(f"Benchmarking {task.name} with model: {model_name}")
eval_config = task.get_eval_config()

for prompt, prompt_name in prompts:

for prompt, prompt_name in prompts[:-1]:

tools = task.create_environment().tools
agent = create_tool_calling_agent(model, tools, prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools)
'''
agent_factory = StandardAgentFactory(
task, model, prompt, rate_limiter=rate_limiter
)

client.run_on_dataset(
dataset_name=dataset_name,
llm_or_chain_factory=agent_factory,
Expand All @@ -176,3 +196,20 @@ def get_prompts(task_name, **kwargs):
"langchain_benchmarks_version": __version__,
},
)
'''
print(agent_executor.invoke({"question":"placeholder question","instructions":task.instructions}))
evaluate(
predict_from_callable(agent_executor,task.instructions),
data=dataset_name,
evaluators=eval_config.evaluators,
max_concurrency=5,
metadata={
"model": model_name,
"id": experiment_uuid,
"task": task.name,
"date": today,
"langchain_benchmarks_version": __version__,
},
)


0 comments on commit a7879c1

Please sign in to comment.