From e07ac06b30bc96a4003fc39f9960c0d0067eb9d7 Mon Sep 17 00:00:00 2001
From: Eugene Yurtsev
Date: Tue, 21 Nov 2023 09:38:14 -0500
Subject: [PATCH] Add more extraction code (#37)

Add more extraction code
---
 langchain_benchmarks/extraction/evaluators.py | 27 ++++++++
 .../extraction/implementations.py             | 68 +++++++++++++++++++
 .../extraction/tasks/__init__.py              |  0
 .../extraction/{ => tasks}/email_task.py      | 40 ++++-------
 langchain_benchmarks/registration.py          |  2 +-
 langchain_benchmarks/schema.py                | 13 +++-
 .../extraction/test_email_extraction.py       |  1 -
 .../extraction/test_import_stuff.py           |  3 +
 8 files changed, 124 insertions(+), 30 deletions(-)
 create mode 100644 langchain_benchmarks/extraction/evaluators.py
 create mode 100644 langchain_benchmarks/extraction/implementations.py
 create mode 100644 langchain_benchmarks/extraction/tasks/__init__.py
 rename langchain_benchmarks/extraction/{ => tasks}/email_task.py (51%)
 create mode 100644 tests/unit_tests/extraction/test_import_stuff.py

diff --git a/langchain_benchmarks/extraction/evaluators.py b/langchain_benchmarks/extraction/evaluators.py
new file mode 100644
index 00000000..5a6f29a1
--- /dev/null
+++ b/langchain_benchmarks/extraction/evaluators.py
@@ -0,0 +1,27 @@
+from langchain.smith import RunEvalConfig
+from pydantic import BaseModel
+
+
+def get_eval_config(eval_llm: BaseModel) -> RunEvalConfig:
+    """Get the evaluation configuration for the email task."""
+    return RunEvalConfig(
+        evaluators=[
+            RunEvalConfig.LabeledScoreString(
+                criteria={
+                    "accuracy": """
+Score 1: The answer is incorrect and unrelated to the question or reference document.
+Score 3: The answer is partially correct but has more than one omission or major errors.
+Score 5: The answer is mostly correct but has more than one omission or major error.
+Score 7: The answer is mostly correct but has at most one omission or major error.
+Score 9: The answer is mostly correct with no omissions and only minor errors, and aligns with the reference document.
+Score 10: The answer is correct, complete, and aligns with the reference document. Extra information is acceptable if it is sensible.
+
+If the reference answer contains multiple alternatives, the predicted answer must only match one of the alternatives to be considered correct.
+If the predicted answer contains additional helpful and accurate information that is not present in the reference answer, it should still be considered correct and not be penalized.
+"""  # noqa
+                },
+                llm=eval_llm,
+                normalize_by=10.0,
+            ),
+        ],
+    )
diff --git a/langchain_benchmarks/extraction/implementations.py b/langchain_benchmarks/extraction/implementations.py
new file mode 100644
index 00000000..a1238a6a
--- /dev/null
+++ b/langchain_benchmarks/extraction/implementations.py
@@ -0,0 +1,68 @@
+"""Default implementations of LLMs that can be used for extraction."""
+from typing import Type, Optional, List, Any, Dict
+
+from langchain.chains.openai_functions import convert_to_openai_function
+from langchain.chat_models import ChatOpenAI
+from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
+from langchain.schema.runnable import Runnable
+from langsmith.client import Client
+from pydantic import BaseModel
+
+from langchain_benchmarks.extraction.evaluators import get_eval_config
+from langchain_benchmarks.schema import ExtractionTask
+
+# PUBLIC API
+
+
+def create_openai_function_based_extractor(
+    llm: Runnable,
+    schema: Type[BaseModel],
+) -> Runnable[dict, dict]:
+    """Create an extraction chain that uses an LLM to extract a schema.
+
+    The underlying functionality is exclusively for LLMs that support
+    extraction using openai functions format.
+
+    Args:
+        llm: The LLM to use for extraction.
+        schema: The schema to extract.
+
+    Returns:
+        An llm that will extract the schema
+    """
+    openai_functions = [convert_to_openai_function(schema)]
+    llm_kwargs = {
+        "functions": openai_functions,
+        "function_call": {"name": openai_functions[0]["name"]},
+    }
+    output_parser = JsonOutputFunctionsParser()
+    extraction_chain = (
+        llm.bind(**llm_kwargs) | output_parser | (lambda x: {"output": x})
+    )
+    return extraction_chain
+
+
+def run_on_dataset(
+    task: ExtractionTask,
+    llm: Runnable,
+    *,
+    tags: Optional[List[str]] = None,
+    **kwargs: Any,
+) -> Dict[str, Any]:
+    """Run an LLM on a dataset.
+
+    Args:
+        task: The task to run on.
+        llm: The LLM to run.
+        tags: The tags to use for the run.
+        kwargs: Additional arguments to pass to the client.
+    """
+    client = Client()
+    eval_llm = ChatOpenAI(model="gpt-4", temperature=0.0, model_kwargs={"seed": 42})
+    return client.run_on_dataset(
+        dataset_name=task.name,
+        llm_or_chain_factory=create_openai_function_based_extractor(llm, task.schema),
+        evaluation=get_eval_config(eval_llm),
+        tags=tags,
+        **kwargs,
+    )
diff --git a/langchain_benchmarks/extraction/tasks/__init__.py b/langchain_benchmarks/extraction/tasks/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/langchain_benchmarks/extraction/email_task.py b/langchain_benchmarks/extraction/tasks/email_task.py
similarity index 51%
rename from langchain_benchmarks/extraction/email_task.py
rename to langchain_benchmarks/extraction/tasks/email_task.py
index e03f138d..d216a4a2 100644
--- a/langchain_benchmarks/extraction/email_task.py
+++ b/langchain_benchmarks/extraction/tasks/email_task.py
@@ -1,7 +1,7 @@
 from enum import Enum
 from typing import Optional, List
 
-from langchain.smith import RunEvalConfig
+from langchain.prompts import ChatPromptTemplate
 from pydantic import BaseModel, Field
 
 from langchain_benchmarks.schema import ExtractionTask
@@ -33,35 +33,22 @@ class Email(BaseModel):
     tone: ToneEnum = Field(..., description="The tone of the email.")
 
 
-def get_eval_config(eval_llm: BaseModel) -> RunEvalConfig:
-    """Get the evaluation configuration for the email task."""
-    return RunEvalConfig(
-        evaluators=[
-            RunEvalConfig.LabeledScoreString(
-                criteria={
-                    "accuracy": """
-Score 1: The answer is incorrect and unrelated to the question or reference document.
-Score 3: The answer is partially correct but has more than one omission or major errors.
-Score 5: The answer is mostly correct but has more than one omission or major error.
-Score 7: The answer is mostly correct but has at most one omission or major error.
-Score 9: The answer is mostly correct with no omissions and only minor errors, and aligns with the reference document.
-Score 10: The answer is correct, complete, and aligns with the reference document. Extra information is acceptable if it is sensible.
-
-If the reference answer contains multiple alternatives, the predicted answer must only match one of the alternatives to be considered correct.
-If the predicted answer contains additional helpful and accurate information that is not present in the reference answer, it should still be considered correct and not be penalized.
-"""  # noqa
-                },
-                llm=eval_llm,
-                normalize_by=10.0,
-            ),
-        ],
-    )
-
+# This is a default prompt that works for chat models.
+DEFAULT_CHAT_MODEL_PROMPT = ChatPromptTemplate.from_messages(
+    [
+        ("system", "You are an expert researcher."),
+        (
+            "human",
+            "What can you tell me about the following email? Make sure to "
+            "answer in the correct format: {schema}",
+        ),
+    ]
+)
 
 EMAIL_EXTRACTION_TASK = ExtractionTask(
     name="Email Extraction",
     dataset_id="https://smith.langchain.com/public/36bdfe7d-3cd1-4b36-b957-d12d95810a2b/d",
-    model=Email,
+    schema=Email,
     description="""\
 A dataset of 42 real emails deduped from a spam folder, with semantic HTML tags removed, \
 as well as a script for initial extraction and formatting of other emails from \
@@ -71,4 +58,5 @@ def get_eval_config(eval_llm: BaseModel) -> RunEvalConfig:
 
 See https://github.com/jacoblee93/oss-model-extraction-evals.
 """,
+    instructions=DEFAULT_CHAT_MODEL_PROMPT,
 )
diff --git a/langchain_benchmarks/registration.py b/langchain_benchmarks/registration.py
index 7ebf8fed..f68d72e1 100644
--- a/langchain_benchmarks/registration.py
+++ b/langchain_benchmarks/registration.py
@@ -1,6 +1,6 @@
 """Registry of environments for ease of access."""
 
-from langchain_benchmarks.extraction import email_task
+from langchain_benchmarks.extraction.tasks import email_task
 from langchain_benchmarks.schema import Registry
 from langchain_benchmarks.tool_usage.tasks import (
     type_writer,
diff --git a/langchain_benchmarks/schema.py b/langchain_benchmarks/schema.py
index eac59c47..17370306 100644
--- a/langchain_benchmarks/schema.py
+++ b/langchain_benchmarks/schema.py
@@ -2,6 +2,7 @@
 import dataclasses
 from typing import List, Callable, Any, Optional, Type, Union
 
+from langchain.prompts import ChatPromptTemplate
 from langchain.tools import BaseTool
 from pydantic import BaseModel
 from tabulate import tabulate
@@ -68,8 +69,16 @@ class ToolUsageTask(BaseTask):
 class ExtractionTask(BaseTask):
     """A definition for an extraction task."""
 
-    model: Type[BaseModel]
-    """Get the model for the task."""
+    schema: Type[BaseModel]
+    """Get schema that specifies what should be extracted."""
+
+    # We might want to make this optional / or support more types
+    # and add validation, but let's wait until we have more examples
+    instructions: ChatPromptTemplate
+    """Get the prompt for the task.
+
+    This is the default prompt to use for the task.
+    """
 
 
 @dataclasses.dataclass(frozen=False)
diff --git a/tests/unit_tests/extraction/test_email_extraction.py b/tests/unit_tests/extraction/test_email_extraction.py
index 57bf3530..e9622150 100644
--- a/tests/unit_tests/extraction/test_email_extraction.py
+++ b/tests/unit_tests/extraction/test_email_extraction.py
@@ -1,3 +1,2 @@
 def test_email_extraction() -> None:
     """Try to import the email task."""
-    from langchain_benchmarks.extraction import email_task  # noqa: F401
diff --git a/tests/unit_tests/extraction/test_import_stuff.py b/tests/unit_tests/extraction/test_import_stuff.py
new file mode 100644
index 00000000..363340bc
--- /dev/null
+++ b/tests/unit_tests/extraction/test_import_stuff.py
@@ -0,0 +1,3 @@
+def test_import_stuff() -> None:
+    """Test that all imports work."""
+    from langchain_benchmarks.extraction import evaluators, implementations  # noqa: F401