diff --git a/langchain_benchmarks/extraction/__init__.py b/langchain_benchmarks/extraction/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/langchain_benchmarks/extraction/email_task.py b/langchain_benchmarks/extraction/email_task.py
new file mode 100644
index 00000000..b823d893
--- /dev/null
+++ b/langchain_benchmarks/extraction/email_task.py
@@ -0,0 +1,74 @@
+from enum import Enum
+from typing import Optional, List
+
+from langchain.chat_models.base import BaseChatModel
+from langchain.smith import RunEvalConfig
+from pydantic import BaseModel, Field
+
+from langchain_benchmarks.schema import ExtractionTask
+
+
+class ToneEnum(str, Enum):
+    """The tone of the email."""
+
+    positive = "positive"
+    negative = "negative"
+
+
+class Email(BaseModel):
+    """Relevant information about an email."""
+
+    sender: Optional[str] = Field(None, description="The sender's name, if available")
+    sender_phone_number: Optional[str] = Field(
+        None, description="The sender's phone number, if available"
+    )
+    sender_address: Optional[str] = Field(
+        None, description="The sender's address, if available"
+    )
+    action_items: List[str] = Field(
+        ..., description="A list of action items requested by the email"
+    )
+    topic: str = Field(
+        ..., description="High level description of what the email is about"
+    )
+    tone: ToneEnum = Field(..., description="The tone of the email.")
+
+
+def get_eval_config(eval_llm: BaseChatModel) -> RunEvalConfig:
+    """Get the evaluation configuration for the email task."""
+    return RunEvalConfig(
+        evaluators=[
+            RunEvalConfig.LabeledScoreString(
+                criteria={
+                    "accuracy": """
+                    Score 1: The answer is incorrect and unrelated to the question or reference document.
+                    Score 3: The answer is partially correct but has more than one omission or major error.
+                    Score 5: The answer is mostly correct but has more than one omission or major error.
+                    Score 7: The answer is mostly correct but has at most one omission or major error.
+                    Score 9: The answer is mostly correct with no omissions and only minor errors, and aligns with the reference document.
+                    Score 10: The answer is correct, complete, and aligns with the reference document. Extra information is acceptable if it is sensible.
+
+                    If the reference answer contains multiple alternatives, the predicted answer must only match one of the alternatives to be considered correct.
+                    If the predicted answer contains additional helpful and accurate information that is not present in the reference answer, it should still be considered correct and not be penalized.
+                    """  # noqa
+                },
+                llm=eval_llm,
+                normalize_by=10.0,
+            ),
+        ],
+    )
+
+
+EmailTask = ExtractionTask(
+    id=4,  # To be deprecated
+    name="Email Extraction",
+    dataset_id="https://smith.langchain.com/public/36bdfe7d-3cd1-4b36-b957-d12d95810a2b/d",
+    model=Email,
+    description="""\
+A dataset of 42 real emails deduped from a spam folder, with semantic HTML tags removed, \
+as well as a script for initial extraction and formatting of other emails from \
+an arbitrary .mbox file like the one exported by Gmail.
+
+Some additional cleanup of the data was done by hand after the initial pass.
+    """,
+)
diff --git a/langchain_benchmarks/registration.py b/langchain_benchmarks/registration.py
index cb1c0f3a..a6720422 100644
--- a/langchain_benchmarks/registration.py
+++ b/langchain_benchmarks/registration.py
@@ -4,7 +4,8 @@
 
 from tabulate import tabulate
 
-from langchain_benchmarks.schema import Task
+from langchain_benchmarks.schema import BaseTask, ToolUsageTask
+from langchain_benchmarks.extraction import email_task
 from langchain_benchmarks.tool_usage.environments import (
     relational_data,
     type_writer,
@@ -15,9 +16,9 @@
 
 @dataclasses.dataclass(frozen=True)
 class Registry:
-    tasks: Sequence[Task]
+    tasks: Sequence[BaseTask]
 
-    def get_task(self, name_or_id: Union[int, str]) -> Task:
+    def get_task(self, name_or_id: Union[int, str]) -> BaseTask:
         """Get the environment with the given name."""
         for env in self.tasks:
             if env.name == name_or_id or env.id == name_or_id:
@@ -58,7 +59,7 @@ def _repr_html_(self) -> str:
         ]
         return tabulate(table, headers=headers, tablefmt="html")
 
-    def __getitem__(self, key: Union[int, str]) -> Task:
+    def __getitem__(self, key: Union[int, str]) -> BaseTask:
         """Get an environment from the registry."""
         if isinstance(key, slice):
             raise NotImplementedError("Slicing is not supported.")
@@ -72,7 +73,7 @@ def __getitem__(self, key: Union[int, str]) -> Task:
 # Using lower case naming to make a bit prettier API when used in a notebook
 registry = Registry(
     tasks=[
-        Task(
+        ToolUsageTask(
             id=0,
             name="Tool Usage - Relational Data",
             dataset_id=relational_data.DATASET_ID,
@@ -103,7 +104,7 @@ def __getitem__(self, key: Union[int, str]) -> Task:
                 """
             ),
         ),
-        Task(
+        ToolUsageTask(
             id=1,
             name="Tool Usage - Typewriter (1 func)",
             dataset_id="placeholder",
@@ -131,7 +132,7 @@ def __getitem__(self, key: Union[int, str]) -> Task:
                 """
             ),
         ),
-        Task(
+        ToolUsageTask(
             id=2,
             name="Tool Usage - Typewriter",
             dataset_id="placeholder",
@@ -161,7 +162,7 @@ def __getitem__(self, key: Union[int, str]) -> Task:
                 """
            ),
         ),
-        Task(
+        ToolUsageTask(
             id=3,
             name="Multiverse Math",
             dataset_id="placeholder",
@@ -187,5 +188,6 @@ def __getitem__(self, key: Union[int, str]) -> Task:
                 """
             ),
         ),
+        email_task.EmailTask,
     ]
 )
diff --git a/langchain_benchmarks/schema.py b/langchain_benchmarks/schema.py
index 06717e6f..c6dc2642 100644
--- a/langchain_benchmarks/schema.py
+++ b/langchain_benchmarks/schema.py
@@ -1,13 +1,14 @@
 """Schema for the Langchain Benchmarks."""
 import dataclasses
-from typing import List, Callable, Any, Optional
+from typing import List, Callable, Any, Optional, Type
 
 from langchain.tools import BaseTool
+from pydantic import BaseModel
 from tabulate import tabulate
 
 
 @dataclasses.dataclass(frozen=True)
-class Environment:
+class ToolUsageEnvironment:
     """An instance of an environment for tool usage."""
 
     tools: List[BaseTool]
@@ -18,8 +19,8 @@ class Environment:
 
 
 @dataclasses.dataclass(frozen=True)
-class Task:
-    """A definition for a task."""
+class BaseTask:
+    """A definition of a task."""
 
     id: int
     """The ID of the environment."""
@@ -28,31 +29,24 @@ class Task:
 
     dataset_id: str
     """The ID of the langsmith public dataset.
-    
+
     This dataset contains expected inputs/outputs for the environment,
     and can be used to evaluate the performance of a model/agent etc.
     """
 
-    create_environment: Callable[
-        [], Environment
-    ]  # Specialized for tool usage; refactor potentially
-    """Factory that returns an environment."""
-
     description: str
     """Description of the task for a data science practitioner.
-    
+
     This can contain information about the task, the dataset,
     the tools available etc.
     """
 
-    instructions: str
-    """Instructions for the agent/chain/llm."""
-
     def _repr_html_(self) -> str:
         """Return an HTML representation of the environment."""
         table = [
             ["ID", self.id],
             ["Name", self.name],
+            ["Type", self.__class__.__name__],
             ["Dataset ID", self.dataset_id],
             ["Description", self.description[:100] + "..."],
         ]
@@ -60,3 +54,22 @@ def _repr_html_(self) -> str:
             table,
             tablefmt="html",
         )
+
+
+@dataclasses.dataclass(frozen=True)
+class ToolUsageTask(BaseTask):
+    """A definition for a tool usage task."""
+
+    create_environment: Callable[[], ToolUsageEnvironment]
+    """Factory that returns an environment."""
+
+    instructions: str
+    """Instructions for the agent/chain/llm."""
+
+
+@dataclasses.dataclass(frozen=True)
+class ExtractionTask(BaseTask):
+    """A definition for an extraction task."""
+
+    model: Type[BaseModel]
+    """The pydantic model describing the information to extract."""
diff --git a/langchain_benchmarks/tool_usage/environments/multiverse_math.py b/langchain_benchmarks/tool_usage/environments/multiverse_math.py
index b52d95cf..c300f785 100644
--- a/langchain_benchmarks/tool_usage/environments/multiverse_math.py
+++ b/langchain_benchmarks/tool_usage/environments/multiverse_math.py
@@ -14,7 +14,7 @@
 
 from langchain.tools import tool, BaseTool
 
-from langchain_benchmarks.schema import Environment
+from langchain_benchmarks.schema import ToolUsageEnvironment
 
 
 def multiply(a: float, b: float) -> float:
@@ -76,13 +76,13 @@ def negate(a: float) -> float:
 # PUBLIC API
 
 
-def get_environment() -> Environment:
+def get_environment() -> ToolUsageEnvironment:
     """Create an environment."""
     tools = cast(
         List[BaseTool],
         [tool(func) for func in [multiply, add, divide, subtract, power, log, negate]],
     )
-    return Environment(
+    return ToolUsageEnvironment(
         tools=tools,
         read_state=None,
     )
diff --git a/langchain_benchmarks/tool_usage/environments/relational_data.py b/langchain_benchmarks/tool_usage/environments/relational_data.py
index d7ce35e2..ba1eaf13 100644
--- a/langchain_benchmarks/tool_usage/environments/relational_data.py
+++ b/langchain_benchmarks/tool_usage/environments/relational_data.py
@@ -12,7 +12,7 @@
 
 from langchain.tools import BaseTool, tool
 
-from langchain_benchmarks.schema import Environment
+from langchain_benchmarks.schema import ToolUsageEnvironment
 
 USER_DATA = [
     # IDs are not consecutive to prevent agents from guessing the ID
@@ -397,9 +397,9 @@ def get_tools() -> List[BaseTool]:
     return [tool(f) for f in functions]
 
 
-def get_environment() -> Environment:
+def get_environment() -> ToolUsageEnvironment:
     """Create an environment."""
-    return Environment(
+    return ToolUsageEnvironment(
         tools=get_tools(),
         read_state=None,
     )
diff --git a/langchain_benchmarks/tool_usage/environments/type_writer.py b/langchain_benchmarks/tool_usage/environments/type_writer.py
index 5b1e55f8..2165762b 100644
--- a/langchain_benchmarks/tool_usage/environments/type_writer.py
+++ b/langchain_benchmarks/tool_usage/environments/type_writer.py
@@ -8,7 +8,7 @@
 
 from langchain.tools import BaseTool, tool
 
-from langchain_benchmarks.schema import Environment
+from langchain_benchmarks.schema import ToolUsageEnvironment
 
 
 @dataclasses.dataclass
@@ -32,7 +32,7 @@ def type_letter(letter: str) -> str:
 # PUBLIC API
 
 
-def get_environment() -> Environment:
+def get_environment() -> ToolUsageEnvironment:
     """Create tools and state reader.
 
     Attention: this is a factory function, so it will create a new environment
@@ -51,7 +51,7 @@ def _read_state() -> Any:
     # tools = cast(List[BaseTool], [tool(f) for f in functions])
     tools = cast(List[BaseTool], [tool(function(paper))])
 
-    return Environment(
+    return ToolUsageEnvironment(
         tools=tools,
         read_state=_read_state,
     )
diff --git a/langchain_benchmarks/tool_usage/environments/type_writer_26_funcs.py b/langchain_benchmarks/tool_usage/environments/type_writer_26_funcs.py
index 8150234a..002156f5 100644
--- a/langchain_benchmarks/tool_usage/environments/type_writer_26_funcs.py
+++ b/langchain_benchmarks/tool_usage/environments/type_writer_26_funcs.py
@@ -8,7 +8,7 @@
 
 from langchain.tools import BaseTool, tool
 
-from langchain_benchmarks.schema import Environment
+from langchain_benchmarks.schema import ToolUsageEnvironment
 
 
 @dataclasses.dataclass
@@ -40,7 +40,7 @@ def _get_available_functions(paper: Paper) -> List[Callable]:
 # PUBLIC API
 
 
-def get_environment() -> Environment:
+def get_environment() -> ToolUsageEnvironment:
     """Create tools and state reader.
 
     Attention: this is a factory function, so it will create a new environment
@@ -58,7 +58,7 @@ def _read_state() -> Any:
 
     tools = cast(List[BaseTool], [tool(f) for f in functions])
 
-    return Environment(
+    return ToolUsageEnvironment(
         tools=tools,
         read_state=_read_state,
     )
diff --git a/tests/unit_tests/extraction/__init__.py b/tests/unit_tests/extraction/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/unit_tests/extraction/test_email_extraction.py b/tests/unit_tests/extraction/test_email_extraction.py
new file mode 100644
index 00000000..57bf3530
--- /dev/null
+++ b/tests/unit_tests/extraction/test_email_extraction.py
@@ -0,0 +1,3 @@
+def test_email_extraction() -> None:
+    """Try to import the email task."""
+    from langchain_benchmarks.extraction import email_task  # noqa: F401
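
For reviewers who want to exercise the new task end to end, here is a minimal sketch of how `EmailTask` and `get_eval_config` are meant to be consumed. None of it is part of the diff itself: the `ChatOpenAI` models, the use of `create_extraction_chain_pydantic` to produce predictions, and the sample input text are all illustrative assumptions.

```python
# Sketch only: exercises the surface added in this PR. Model and chain choices
# are assumptions, not part of the change itself.
from langchain.chains import create_extraction_chain_pydantic
from langchain.chat_models import ChatOpenAI

from langchain_benchmarks.extraction.email_task import EmailTask, get_eval_config
from langchain_benchmarks.registration import registry

# Registry lookup works by name or id, exactly as for the tool-usage tasks.
task = registry["Email Extraction"]
assert task is EmailTask

# Any runnable that emits objects matching task.model (the Email schema) will do;
# an extraction chain over a chat model is one straightforward option.
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
chain = create_extraction_chain_pydantic(pydantic_schema=task.model, llm=llm)

# The labeled-score evaluator grades predictions against the reference outputs
# in the dataset pointed to by task.dataset_id, normalized by 10.
eval_config = get_eval_config(eval_llm=ChatOpenAI(model="gpt-4", temperature=0))

# Quick smoke test on a made-up email body.
print(chain.run("Hi, this is Jane at 555-0100. Please send the Q3 report by Friday."))
```

Running the full benchmark would additionally require a LangSmith client and a `run_on_dataset(...)` call against a copy of the public dataset; that is omitted here since it needs credentials.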