
Commit

editing tool benchmarks
isahers1 committed Jul 16, 2024
1 parent be68210 commit e288225
Showing 1 changed file with 77 additions and 72 deletions.
149 changes: 77 additions & 72 deletions scripts/tool_benchmarks.py
@@ -40,49 +40,40 @@

client = Client() # LangSmith client, used to fetch few-shot examples and run evaluations

def get_few_shot_messages(task_name):
if task_name == "Multiverse Math":
uncleaned_examples = list(
client.list_examples(dataset_name="multiverse-math-examples-for-few-shot")
)
few_shot_messages = []
examples = []
for example in uncleaned_examples:
converted_messages = convert_to_messages(example.outputs["output"])
examples.append(
# The message at index 1 is the human message asking the actual math question
# (the message at index 0 is the system prompt)
{
"question": converted_messages[1].content,
"messages": [
m for m in converted_messages if not isinstance(m, SystemMessage)
],
}
)
few_shot_messages += converted_messages

return examples, [
m for m in few_shot_messages if not isinstance(m, SystemMessage)
]
else:
raise ValueError(f"Few-shot messages are not supported for dataset: {task_name}")
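
Note: convert_to_messages above is the langchain_core utility that turns stored role/content pairs into message objects. A minimal standalone sketch of the indexing and filtering pattern the helper relies on (the toy transcript below is invented for illustration; only the dataset name in the code above is real):

# Standalone sketch, not part of the script: how convert_to_messages
# produces the list that get_few_shot_messages indexes and filters.
from langchain_core.messages import SystemMessage, convert_to_messages

raw_output = [
    ("system", "You operate in an alternate mathematical universe."),
    ("human", "what is 2 multiplied by 3?"),
    ("ai", "In this universe, 2 * 3 = 6.6"),
]
converted = convert_to_messages(raw_output)

question = converted[1].content  # index 1 is the human question
messages = [m for m in converted if not isinstance(m, SystemMessage)]
print(question)                    # -> what is 2 multiplied by 3?
print([m.type for m in messages])  # -> ['human', 'ai']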

experiment_uuid = uuid.uuid4().hex[:4]
today = datetime.date.today().isoformat()
for task in registry.tasks:
if task.type != "ToolUsageTask":
continue
# This is a small test dataset that can be used to verify
# that everything is set up correctly prior to running over
# all results. We may remove it in the future.
if task.name != "Multiverse Math":
continue

dataset_name = task.name

uncleaned_examples = [
e
for e in client.list_examples(
dataset_name="multiverse-math-examples-for-few-shot"
)
]
few_shot_messages = []
examples = []
for i in range(len(uncleaned_examples)):
converted_messages = convert_to_messages(
uncleaned_examples[i].outputs["output"]
)
examples.append(
# The message at index 1 is the human message (0th message is system prompt)
{
"question": converted_messages[1].content,
"messages": [
m
for m in converted_messages
if isinstance(m, SystemMessage) == False
],
}
)
few_shot_messages += converted_messages

few_shot_messages = [
m for m in few_shot_messages if not isinstance(m, SystemMessage)
]

def get_few_shot_str_from_messages(few_shot_messages):
few_shot_str = ""
for m in few_shot_messages:
if isinstance(m.content, list):
@@ -101,8 +92,10 @@

few_shot_str += "\n"
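
Note: the middle of this helper is collapsed by the diff viewer (the @@ -101,8 +92,10 @@ hunk marker above hides several lines), so its full body is not shown here. Purely as an illustration of the flattening pattern, and explicitly not the commit's hidden code, a hypothetical version looks like:

# Hypothetical stand-in for the collapsed helper, NOT the commit's code:
# flatten chat messages into one transcript string.
def flatten_messages(messages) -> str:
    few_shot_str = ""
    for m in messages:
        if isinstance(m.content, list):
            # Tool-calling models can emit content as a list of blocks;
            # keep only the text blocks.
            text = " ".join(
                block.get("text", "")
                for block in m.content
                if isinstance(block, dict)
            )
        else:
            text = m.content
        few_shot_str += f"{m.type}: {text}\n"
    return few_shot_str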

def get_prompts(task_name, **kwargs):
if task_name == "Multiverse Math":
example_selector = SemanticSimilarityExampleSelector.from_examples(
examples,
kwargs["examples"],
OpenAIEmbeddings(),
FAISS,
k=3,
@@ -115,8 +108,7 @@
example_selector=example_selector,
example_prompt=MessagesPlaceholder("messages"),
)

prompts = [
return [
(
hub.pull("multiverse-math-no-few-shot"),
"no-few-shot",
@@ -141,33 +133,46 @@
MessagesPlaceholder("agent_scratchpad"),
]
),
"few-shot-semantic",
"few-shot-semantic-openai-embeddinga",
),
]
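
Note: for readers unfamiliar with the selector constructed above, SemanticSimilarityExampleSelector embeds the text of each example's input keys, stores the vectors in a FAISS index, and returns the k nearest examples for each incoming question. A minimal standalone sketch (the toy examples are invented; assumes OPENAI_API_KEY is set and faiss-cpu is installed):

# Standalone sketch of semantic few-shot selection; toy data, not the
# benchmark's dataset.
from langchain_community.vectorstores import FAISS
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings

toy_examples = [
    {"question": "what is 2 multiplied by 3?", "answer": "6.6"},
    {"question": "add 5 and 7", "answer": "13.2"},
    {"question": "what is the cosine of pi?", "answer": "0.8"},
]

selector = SemanticSimilarityExampleSelector.from_examples(
    toy_examples,
    OpenAIEmbeddings(),       # embeds the text of input_keys
    FAISS,                    # in-memory vector store backing the search
    k=2,                      # number of nearest examples to return
    input_keys=["question"],  # embed only the question field
)

# Embeds the new question and returns the 2 most similar stored examples.
print(selector.select_examples({"question": "multiply 4 by 6"}))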

for model_name, model in tests[-2:-1]:
rate_limiter = RateLimiter(requests_per_second=1)

print(f"Benchmarking {task.name} with model: {model_name}")
eval_config = task.get_eval_config()

for prompt, prompt_name in prompts:
agent_factory = StandardAgentFactory(
task, model, prompt, rate_limiter=rate_limiter
)

client.run_on_dataset(
dataset_name=dataset_name,
llm_or_chain_factory=agent_factory,
evaluation=eval_config,
verbose=False,
project_name=f"{model_name}-{task.name}-{prompt_name}-{experiment_uuid}",
concurrency_level=5,
project_metadata={
"model": model_name,
"id": experiment_uuid,
"task": task.name,
"date": today,
"langchain_benchmarks_version": __version__,
},
)
experiment_uuid = uuid.uuid4().hex[:4]
today = datetime.date.today().isoformat()
for task in registry.tasks:
if task.type != "ToolUsageTask" or task.name != "Multiverse Math":
continue

dataset_name = task.name

examples, few_shot_messages = get_few_shot_messages(task.name)
few_shot_str = get_few_shot_str_from_messages(few_shot_messages)
prompts = get_prompts(task.name, examples=examples)

# Fireworks API limit reached, so only benchmark a subset of models for now
# (tests[-2:-1] selects just the second-to-last model)
for model_name, model in tests[-2:-1]:
rate_limiter = RateLimiter(requests_per_second=1)

print(f"Benchmarking {task.name} with model: {model_name}")
eval_config = task.get_eval_config()

for prompt, prompt_name in prompts:
agent_factory = StandardAgentFactory(
task, model, prompt, rate_limiter=rate_limiter
)

client.run_on_dataset(
dataset_name=dataset_name,
llm_or_chain_factory=agent_factory,
evaluation=eval_config,
verbose=False,
project_name=f"{model_name}-{task.name}-{prompt_name}-{experiment_uuid}",
concurrency_level=5,
project_metadata={
"model": model_name,
"id": experiment_uuid,
"task": task.name,
"date": today,
"langchain_benchmarks_version": __version__,
},
)
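
Note on the rate_limiter passed to StandardAgentFactory: RateLimiter here is the langchain-benchmarks helper, throttling the agent to one request per second. The same client-side pattern is available generically via langchain-core's built-in limiter; a sketch under the assumption that langchain-core >= 0.2.24 and langchain-openai are installed (the model name is arbitrary):

# Generic client-side throttling sketch; uses langchain_core's
# InMemoryRateLimiter rather than the benchmark's RateLimiter.
from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_openai import ChatOpenAI

limiter = InMemoryRateLimiter(
    requests_per_second=1,      # matches the 1 req/s used in this script
    check_every_n_seconds=0.1,  # how often waiting callers poll for a token
    max_bucket_size=1,          # no bursting above the steady rate
)
llm = ChatOpenAI(model="gpt-4o-mini", rate_limiter=limiter)

# Every invoke now blocks until a token is available, keeping the run
# under provider rate limits.
print(llm.invoke("what is 2 multiplied by 3?").content)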
