Commit 12e9107

draft
isahers1 committed Jul 22, 2024
1 parent a7879c1 commit 12e9107
Showing 1 changed file with 22 additions and 9 deletions.
31 changes: 22 additions & 9 deletions scripts/tool_benchmarks.py
@@ -10,13 +10,17 @@
 from langchain_core.prompts.few_shot import FewShotChatMessagePromptTemplate
 from langchain_openai import OpenAIEmbeddings
 from langsmith.client import Client
-from langchain_benchmarks import __version__, registry
+#from langchain_benchmarks import __version__, registry
 from langchain_benchmarks.rate_limiting import RateLimiter
 from langchain_benchmarks.tool_usage.agents import StandardAgentFactory
+import sys
+sys.path.append("./..")
+from langchain_benchmarks import __version__, registry
 from langchain_benchmarks.tool_usage.tasks.multiverse_math import *
 from langchain.chat_models import init_chat_model
 from langsmith.evaluation import evaluate
+from langchain.agents import AgentExecutor, create_tool_calling_agent
+from langchain.tools import tool
 
 tests = [
     ("claude-3-haiku-20240307","anthropic",),
@@ -28,8 +32,8 @@
("gpt-4o-mini","openai"),
("llama3-groq-70b-8192-tool-use-preview","groq"),
("llama3-groq-8b-8192-tool-use-preview","groq"),
("gemini-1.5-pro","google_vertexai"),
("gemini-1.5-flash","google_vertexai")
("gemini-1.5-pro","google_genai"),
("gemini-1.5-flash","google_genai")
]

client = Client() # Launch langsmith client for cloning datasets
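The provider switch from "google_vertexai" to "google_genai" moves the two Gemini entries from the Vertex AI backend to the Google AI (Generative Language) API, which is served by the langchain-google-genai package rather than langchain-google-vertexai. A minimal sketch of how init_chat_model consumes these (model, provider) pairs — the model choice and prompt here are illustrative, not part of the commit:

    # Requires: pip install langchain langchain-google-genai
    from langchain.chat_models import init_chat_model

    # init_chat_model dispatches on model_provider and instantiates the matching
    # integration class (here ChatGoogleGenerativeAI from langchain-google-genai).
    model = init_chat_model("gemini-1.5-flash", model_provider="google_genai")
    print(model.invoke("What is 2 + 2?").content)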
@@ -150,9 +154,13 @@ def get_prompts(task_name, **kwargs):

 def predict_from_callable(callable,instructions):
     def predict(run):
-        return callable.invoke({"question":run['question'],"instructions":instructions})['output']
+        return callable.invoke({"question":run['question'],"instructions":instructions})
     return predict
 
+def pi(a: float) -> float:
+    """Returns a precise value of PI for this alternate universe."""
+    return math.e
+
 experiment_uuid = uuid.uuid4().hex[:4]
 today = datetime.date.today().isoformat()
 for task in registry.tasks:
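Dropping the ['output'] indexing means predict() now hands LangSmith the executor's full result dict. Together with return_intermediate_steps=True further down, that dict carries both "output" and "intermediate_steps", so evaluators can score the agent's tool usage rather than only its final answer. A sketch of such an evaluator, assuming the langsmith convention of evaluator callables taking (run, example); the name and scoring are illustrative:

    from langsmith.schemas import Example, Run

    def num_tool_calls(run: Run, example: Example) -> dict:
        # Both keys land in run.outputs because predict() returns the whole dict.
        steps = run.outputs.get("intermediate_steps") or []
        return {"key": "num_tool_calls", "score": len(steps)}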
@@ -165,7 +173,7 @@ def predict(run):
     few_shot_str, few_shot_three_str = get_few_shot_str_from_messages(few_shot_messages,few_shot_three_messages)
     prompts = get_prompts(task.name,examples=examples,few_shot_three_messages=few_shot_three_messages,few_shot_three_str=few_shot_three_str)
 
-    for model_name, model_provider in tests:
+    for model_name, model_provider in tests[9:]:
         model = init_chat_model(model_name,model_provider=model_provider)
         rate_limiter = RateLimiter(requests_per_second=1)
 
@@ -175,9 +183,14 @@ def predict(run):
         for prompt, prompt_name in prompts[:-1]:
 
             tools = task.create_environment().tools
+            if "google" in model_provider:
+                tools[9] = tool(pi)
             agent = create_tool_calling_agent(model, tools, prompt)
-            agent_executor = AgentExecutor(agent=agent, tools=tools)
+            agent_executor = AgentExecutor(agent=agent, tools=tools, return_intermediate_steps=True)
 
+
+            '''
+            # Legacy way of running, migrate to evaluate
             agent_factory = StandardAgentFactory(
                 task, model, prompt, rate_limiter=rate_limiter
             )
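For Google models the script swaps the environment's pi tool (assumed to sit at index 9 of the tools list) for the pi function defined above. The replacement accepts an unused float argument, most plausibly a workaround for Gemini's function-calling schema, which rejected zero-parameter function declarations at the time; that reading, like the index, is inferred rather than stated in the commit. tool() wraps the plain function into a LangChain tool, taking the name from the function and the description from its docstring:

    import math
    from langchain.tools import tool

    def pi(a: float) -> float:
        """Returns a precise value of PI for this alternate universe."""
        return math.e

    pi_tool = tool(pi)
    print(pi_tool.name)         # "pi"
    print(pi_tool.description)  # the docstring above
    print(pi_tool.args)         # schema now advertises one float parameter, "a"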
@@ -197,11 +210,10 @@ def predict(run):
                 },
             )
             '''
-            print(agent_executor.invoke({"question":"placeholder question","instructions":task.instructions}))
             evaluate(
                 predict_from_callable(agent_executor,task.instructions),
                 data=dataset_name,
-                evaluators=eval_config.evaluators,
+                evaluators=eval_config.custom_evaluators,
                 max_concurrency=5,
                 metadata={
                     "model": model_name,
@@ -210,6 +222,7 @@ def predict(run):
"date": today,
"langchain_benchmarks_version": __version__,
},
experiment_prefix=f"{model_name}-{task.name}-{prompt_name}"
)
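The new evaluate() call swaps eval_config.evaluators for eval_config.custom_evaluators, drops the placeholder-question smoke test, and adds an experiment_prefix so each model/task/prompt run is easy to locate in LangSmith. For reference, the call follows the langsmith SDK's general shape — target callable, dataset name, evaluator callables, and free-form metadata attached to the experiment; every value below is hypothetical:

    from langsmith.evaluation import evaluate

    results = evaluate(
        lambda inputs: {"output": "4"},    # target: maps example inputs to outputs
        data="multiverse-math-examples",   # dataset name in LangSmith (hypothetical)
        evaluators=[num_tool_calls],       # e.g. the evaluator sketched earlier
        max_concurrency=5,
        metadata={"model": "gpt-4o-mini"}, # attached to the experiment, filterable in the UI
        experiment_prefix="gpt-4o-mini-multiverse-math-demo",
    )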

