Add more extraction code #37

Merged: 3 commits, Nov 21, 2023
Changes from 2 commits
27 changes: 27 additions & 0 deletions langchain_benchmarks/extraction/evaluators.py
@@ -0,0 +1,27 @@
from langchain.smith import RunEvalConfig
from pydantic import BaseModel


def get_eval_config(eval_llm: BaseModel) -> RunEvalConfig:
"""Get the evaluation configuration for the email task."""
return RunEvalConfig(
evaluators=[
RunEvalConfig.LabeledScoreString(
criteria={
"accuracy": """
Score 1: The answer is incorrect and unrelated to the question or reference document.
Score 3: The answer is partially correct but has more than one omission or major errors.
Score 5: The answer is mostly correct but has more than one omission or major error.
Score 7: The answer is mostly correct but has at most one omission or major error.
Score 9: The answer is mostly correct with no omissions and only minor errors, and aligns with the reference document.
Score 10: The answer is correct, complete, and aligns with the reference document. Extra information is acceptable if it is sensible.

If the reference answer contains multiple alternatives, the predicted answer must only match one of the alternatives to be considered correct.
If the predicted answer contains additional helpful and accurate information that is not present in the reference answer, it should still be considered correct and not be penalized.
""" # noqa
},
llm=eval_llm,
normalize_by=10.0,
),
],
)
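
For reference, a minimal sketch of how this config is meant to be consumed. The GPT-4 judge below mirrors the eval LLM wired up in implementations.py later in this PR; the exact model and seed settings are assumptions rather than requirements.

from langchain.chat_models import ChatOpenAI

from langchain_benchmarks.extraction.evaluators import get_eval_config

# Judge model used to score predictions against the reference answers.
eval_llm = ChatOpenAI(model="gpt-4", temperature=0.0, model_kwargs={"seed": 42})
eval_config = get_eval_config(eval_llm)
# `eval_config` is then passed as `evaluation=...` to LangSmith's
# `Client.run_on_dataset`, as implementations.run_on_dataset does below.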
67 changes: 67 additions & 0 deletions langchain_benchmarks/extraction/implementations.py
@@ -0,0 +1,67 @@
"""Default implementations of LLMs that can be used for extraction."""
from typing import Type, Optional, List, Any, Dict

from langchain.chains.openai_functions import convert_to_openai_function
from langchain.chat_models import ChatOpenAI
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.schema.runnable import Runnable
from langsmith.client import Client
from pydantic import BaseModel

from langchain_benchmarks.extraction.evaluators import get_eval_config
from langchain_benchmarks.schema import ExtractionTask

# PUBLIC API

Review thread on create_openai_function_based_extractor:

Collaborator: ooc: why are these diff than LangChain?

Collaborator (Author): what are they in langchain? I just cargo culted some code?

Collaborator: well there was create_extraction_chain but i guess we no longer use it

Collaborator (Author): will review right now if there's a way to re-use it


def create_openai_function_based_extractor(
llm: Runnable,
schema: Type[BaseModel],
) -> Runnable[dict, dict]:
"""Create an extraction chain that uses an LLM to extract a schema.

The underlying functionality is exclusively for LLMs that support
extraction using openai functions format.

Args:
llm: The LLM to use for extraction.
schema: The schema to extract.

Returns:
An llm that will extract the schema
"""
openai_functions = [convert_to_openai_function(schema)]
llm_kwargs = {
"functions": openai_functions,
"function_call": {"name": openai_functions[0]["name"]},
}
output_parser = JsonOutputFunctionsParser()
extraction_chain = (
llm.bind(**llm_kwargs) | output_parser | (lambda x: {"output": x})
)
return extraction_chain
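
For reference, a usage sketch for calling the extractor directly, assuming a chat model that supports OpenAI function calling. The Email schema import, the model choice, and the sample email text are illustrative placeholders, not part of the diff.

from langchain.chat_models import ChatOpenAI

from langchain_benchmarks.extraction.implementations import (
    create_openai_function_based_extractor,
)
from langchain_benchmarks.extraction.tasks.email_task import Email

llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
extractor = create_openai_function_based_extractor(llm, Email)
# The chain binds the schema as an OpenAI function, forces that function call,
# parses the returned arguments as JSON, and wraps them as {"output": ...}.
result = extractor.invoke("Hi team, the Q3 numbers are attached. Thanks, Dana")
print(result["output"])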


def run_on_dataset(
task: ExtractionTask,
llm: Runnable,
*,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
"""Run an LLM on a dataset.

Args:
task: The task to run on.
llm: The LLM to run.
tags: The tags to use for the run.
kwargs: Additional arguments to pass to the client.
"""
client = Client()
eval_llm = ChatOpenAI(model="gpt-4", temperature=0.0, model_kwargs={"seed": 42})
return client.run_on_dataset(
dataset_name=task.name,
llm_or_chain_factory=create_openai_function_based_extractor(llm, task.schema),
evaluation=get_eval_config(eval_llm),
tags=tags,
**kwargs,
)
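
For reference, a sketch of running the email task end to end through this helper. It assumes OPENAI_API_KEY and LangSmith credentials are configured and that the "Email Extraction" dataset is accessible to the authenticated account; the model and tags are placeholders.

from langchain.chat_models import ChatOpenAI

from langchain_benchmarks.extraction.implementations import run_on_dataset
from langchain_benchmarks.extraction.tasks.email_task import EMAIL_EXTRACTION_TASK

results = run_on_dataset(
    task=EMAIL_EXTRACTION_TASK,
    llm=ChatOpenAI(model="gpt-3.5-turbo", temperature=0),
    tags=["email-extraction-baseline"],
)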
Empty file.
langchain_benchmarks/extraction/tasks/email_task.py
@@ -1,7 +1,7 @@
from enum import Enum
from typing import Optional, List

from langchain.smith import RunEvalConfig
from langchain.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field

from langchain_benchmarks.schema import ExtractionTask
@@ -33,35 +33,22 @@ class Email(BaseModel):
tone: ToneEnum = Field(..., description="The tone of the email.")


def get_eval_config(eval_llm: BaseModel) -> RunEvalConfig:
"""Get the evaluation configuration for the email task."""
return RunEvalConfig(
evaluators=[
RunEvalConfig.LabeledScoreString(
criteria={
"accuracy": """
Score 1: The answer is incorrect and unrelated to the question or reference document.
Score 3: The answer is partially correct but has more than one omission or major errors.
Score 5: The answer is mostly correct but has more than one omission or major error.
Score 7: The answer is mostly correct but has at most one omission or major error.
Score 9: The answer is mostly correct with no omissions and only minor errors, and aligns with the reference document.
Score 10: The answer is correct, complete, and aligns with the reference document. Extra information is acceptable if it is sensible.

If the reference answer contains multiple alternatives, the predicted answer must only match one of the alternatives to be considered correct.
If the predicted answer contains additional helpful and accurate information that is not present in the reference answer, it should still be considered correct and not be penalized.
""" # noqa
},
llm=eval_llm,
normalize_by=10.0,
),
],
)

# This is a default prompt that works for chat models.
DEFAULT_CHAT_MODEL_PROMPT = ChatPromptTemplate.from_messages(
[
("system", "You are an expert researcher."),
(
"human",
"What can you tell me about the following email? Make sure to "
"answer in the correct format: {schema}",
),
]
)

EMAIL_EXTRACTION_TASK = ExtractionTask(
name="Email Extraction",
dataset_id="https://smith.langchain.com/public/36bdfe7d-3cd1-4b36-b957-d12d95810a2b/d",
model=Email,
schema=Email,
description="""\
A dataset of 42 real emails deduped from a spam folder, with semantic HTML tags removed, \
as well as a script for initial extraction and formatting of other emails from \
@@ -71,4 +58,5 @@ def get_eval_config(eval_llm: BaseModel) -> RunEvalConfig:

See https://github.com/jacoblee93/oss-model-extraction-evals.
""",
instructions=DEFAULT_CHAT_MODEL_PROMPT,
)
2 changes: 1 addition & 1 deletion langchain_benchmarks/registration.py
@@ -1,6 +1,6 @@
"""Registry of environments for ease of access."""

from langchain_benchmarks.extraction import email_task
from langchain_benchmarks.extraction.tasks import email_task
from langchain_benchmarks.schema import Registry
from langchain_benchmarks.tool_usage.tasks import (
type_writer,
13 changes: 11 additions & 2 deletions langchain_benchmarks/schema.py
@@ -2,6 +2,7 @@
import dataclasses
from typing import List, Callable, Any, Optional, Type, Union

from langchain.prompts import ChatPromptTemplate
from langchain.tools import BaseTool
from pydantic import BaseModel
from tabulate import tabulate
@@ -68,8 +69,16 @@ class ToolUsageTask(BaseTask):
class ExtractionTask(BaseTask):
"""A definition for an extraction task."""

model: Type[BaseModel]
"""Get the model for the task."""
schema: Type[BaseModel]
"""Get schema that specifies what should be extracted."""

# We might want to make this optional / or support more types
# and add validation, but let's wait until we have more examples
instructions: ChatPromptTemplate
"""Get the prompt for the task.

This is the default prompt to use for the task.
"""


@dataclasses.dataclass(frozen=False)
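For reference, a hypothetical task definition showing how the renamed schema field and the new instructions prompt fit together. The Invoice model, dataset URL, and prompt text below are invented for illustration; only the ExtractionTask fields follow the code in this PR.

from langchain.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field

from langchain_benchmarks.schema import ExtractionTask


class Invoice(BaseModel):
    """Fields to pull out of an invoice email."""

    vendor: str = Field(..., description="Name of the vendor.")
    total: float = Field(..., description="Invoice total in USD.")


INVOICE_EXTRACTION_TASK = ExtractionTask(
    name="Invoice Extraction",
    dataset_id="https://smith.langchain.com/public/<dataset-uuid>/d",  # placeholder
    schema=Invoice,
    description="A hypothetical invoice dataset.",
    instructions=ChatPromptTemplate.from_messages(
        [
            ("system", "You are an expert bookkeeper."),
            (
                "human",
                "Extract the requested fields from the following invoice. "
                "Answer in the correct format: {schema}",
            ),
        ]
    ),
)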
10 changes: 5 additions & 5 deletions poetry.lock

Some generated files are not rendered by default.

1 change: 0 additions & 1 deletion tests/unit_tests/extraction/test_email_extraction.py
@@ -1,3 +1,2 @@
def test_email_extraction() -> None:
"""Try to import the email task."""
from langchain_benchmarks.extraction import email_task # noqa: F401
3 changes: 3 additions & 0 deletions tests/unit_tests/extraction/test_import_stuff.py
@@ -0,0 +1,3 @@
def test_import_stuff() -> None:
"""Test that all imports work."""
from langchain_benchmarks.extraction import evaluators, implementations # noqa: F401