diff --git a/langchain_benchmarks/__init__.py b/langchain_benchmarks/__init__.py
index e69de29b..7b42085f 100644
--- a/langchain_benchmarks/__init__.py
+++ b/langchain_benchmarks/__init__.py
@@ -0,0 +1,4 @@
+from .utils._langsmith import clone_dataset
+
+# Please keep this list sorted!
+__all__ = ["clone_dataset"]
diff --git a/langchain_benchmarks/utils/__init__.py b/langchain_benchmarks/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/langchain_benchmarks/utils/_langsmith.py b/langchain_benchmarks/utils/_langsmith.py
new file mode 100644
index 00000000..7645a23f
--- /dev/null
+++ b/langchain_benchmarks/utils/_langsmith.py
@@ -0,0 +1,54 @@
+"""Copy the public dataset to your own langsmith tenant."""
+from langsmith import Client
+from langsmith.utils import LangSmithNotFoundError
+from tqdm import tqdm
+
+# PUBLIC API
+
+
+def clone_dataset(
+    public_dataset_token: str,
+    dataset_name: str,
+) -> None:
+    """Clone a public dataset to your own langsmith tenant.
+
+    This operation is idempotent. If you already have a dataset with the given name,
+    this function will do nothing.
+
+    Args:
+        public_dataset_token (str): The token of the public dataset to clone.
+        dataset_name (str): The name of the dataset to create in your tenant.
+    """
+    client = Client()
+
+    try:
+        client.read_dataset(dataset_name=dataset_name)
+    except LangSmithNotFoundError:
+        pass
+    else:
+        print(f"Dataset {dataset_name} already exists. Skipping.")
+        return
+
+    # Fetch examples first
+    examples = tqdm(list(client.list_shared_examples(public_dataset_token)))
+    print("Finished fetching examples. Creating dataset...")
+    dataset = client.create_dataset(dataset_name=dataset_name)
+    try:
+        client.create_examples(
+            inputs=[e.inputs for e in examples],
+            outputs=[e.outputs for e in examples],
+            dataset_id=dataset.id,
+        )
+    except BaseException as e:
+        # Let's not do automatic clean up for now in case there might be
+        # some other reasons why create_examples fails (i.e., not network issue or
+        # keyboard interrupt).
+        # The risk is that this is an existing dataset that has valid examples
+        # populated from another source so we don't want to delete it.
+        print(
+            f"An error occurred while creating dataset {dataset_name}. "
+            "You should delete it manually."
+        )
+        raise e
+
+    print("Done creating dataset.")
diff --git a/poetry.lock b/poetry.lock
index a0b7e070..7cc06f82 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -3133,6 +3133,26 @@ files = [
     {file = "tornado-6.3.3.tar.gz", hash = "sha256:e7d8db41c0181c80d76c982aacc442c0783a2c54d6400fe028954201a2e032fe"},
 ]
 
+[[package]]
+name = "tqdm"
+version = "4.66.1"
+description = "Fast, Extensible Progress Meter"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "tqdm-4.66.1-py3-none-any.whl", hash = "sha256:d302b3c5b53d47bce91fea46679d9c3c6508cf6332229aa1e7d8653723793386"},
+    {file = "tqdm-4.66.1.tar.gz", hash = "sha256:d88e651f9db8d8551a62556d3cff9e3034274ca5d66e93197cf2490e2dcb69c7"},
+]
+
+[package.dependencies]
+colorama = {version = "*", markers = "platform_system == \"Windows\""}
+
+[package.extras]
+dev = ["pytest (>=6)", "pytest-cov", "pytest-timeout", "pytest-xdist"]
+notebook = ["ipywidgets (>=6)"]
+slack = ["slack-sdk"]
+telegram = ["requests"]
+
 [[package]]
 name = "traitlets"
 version = "5.13.0"
@@ -3514,4 +3534,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.8.1"
-content-hash = "d03b1398809fc7aa91d1bbdc41a6e0e04992a91b69d51bc1362af43a0eba632c"
+content-hash = "cce4d8c816ccfa58210fedc8a161a960cd4554777ed30a5da9e236be5c7c87db"
diff --git a/pyproject.toml b/pyproject.toml
index 8e3309ab..174f0463 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,6 +9,8 @@ readme = "README.md"
 [tool.poetry.dependencies]
 python = "^3.8.1"
 langchain = ">=0.0.333"
+langsmith = "^0.0.64"
+tqdm = "^4.66.1"
 
 [tool.poetry.group.dev.dependencies]
 jupyterlab = "^3.6.1"
diff --git a/tests/unit_tests/test_public_api.py b/tests/unit_tests/test_public_api.py
new file mode 100644
index 00000000..c549a48a
--- /dev/null
+++ b/tests/unit_tests/test_public_api.py
@@ -0,0 +1,12 @@
+from langchain_benchmarks import __all__
+
+
+def test_public_api() -> None:
+    """Test that the public API is correct."""
+    # This test will also fail if __all__ is not sorted.
+    # Please keep it sorted!
+    assert __all__ == sorted(
+        [
+            "clone_dataset",
+        ]
+    )
diff --git a/tests/unit_tests/test_utils.py b/tests/unit_tests/test_utils.py
new file mode 100644
index 00000000..05fd6de6
--- /dev/null
+++ b/tests/unit_tests/test_utils.py
@@ -0,0 +1,104 @@
+import datetime
+import unittest.mock as mock
+import uuid
+from contextlib import contextmanager
+from typing import Any, Generator, List, Mapping, Optional, Sequence
+from uuid import UUID
+
+from langsmith.client import ID_TYPE
+from langsmith.schemas import Dataset, Example
+from langsmith.utils import LangSmithNotFoundError
+
+from langchain_benchmarks.utils._langsmith import clone_dataset
+
+
+# Define a mock Client class that overrides the required methods
+class MockLangSmithClient:
+    def __init__(self) -> None:
+        """Initialize the mock client."""
+        self.datasets = []
+        self.examples = []
+
+    def read_dataset(self, dataset_name: str) -> Dataset:
+        for dataset in self.datasets:
+            if dataset.name == dataset_name:
+                return dataset
+        raise LangSmithNotFoundError(f'Dataset "{dataset_name}" not found.')
+
+    def create_dataset(self, dataset_name: str) -> Dataset:
+        # Simulate creating a dataset and returning a mock Dataset object
+        dataset = Dataset(
+            id=UUID(int=3), name=dataset_name, created_at=datetime.datetime(2021, 1, 1)
+        )
+        self.datasets.append(dataset)
+        return dataset
+
+    def create_examples(
+        self,
+        *,
+        inputs: Sequence[Mapping[str, Any]],
+        outputs: Optional[Sequence[Optional[Mapping[str, Any]]]] = None,
+        dataset_id: Optional[ID_TYPE] = None,
+        dataset_name: Optional[str] = None,
+        max_concurrency: int = 10,
+    ) -> None:
+        """Create examples"""
+        examples = []
+        for idx, (input, output) in enumerate(zip(inputs, outputs)):
+            examples.append(
+                Example(
+                    id=UUID(int=idx),
+                    inputs=input,
+                    outputs=output,
+                    created_at=datetime.datetime(2021, 1, 1),
+                    dataset_id=dataset_id,
+                    dataset_name=dataset_name,
+                )
+            )
+
+        return self.examples.extend(examples)
+
+    def list_shared_examples(self, public_dataset_token: str) -> List[Example]:
+        # Simulate fetching shared examples and returning a list of Example objects
+        example1 = Example(
+            id=UUID(int=1),
+            inputs={"a": 1},
+            outputs={},
+            created_at=datetime.datetime(2021, 1, 1),
+            dataset_id=public_dataset_token,
+        )
+        example2 = Example(
+            id=UUID(int=2),
+            inputs={"b": 2},
+            outputs={},
+            created_at=datetime.datetime(2021, 1, 1),
+            dataset_id=public_dataset_token,
+        )
+        return [example1, example2]
+
+
+@contextmanager
+def mock_langsmith_client() -> Generator[MockLangSmithClient, None, None]:
+    """Mock the langsmith Client class."""
+    from langchain_benchmarks.utils import _langsmith
+
+    mock_client = MockLangSmithClient()
+
+    with mock.patch.object(_langsmith, "Client") as client:
+        client.return_value = mock_client
+        yield mock_client
+
+
+def test_clone_dataset() -> None:
+    # Call the clone_dataset function with mock data
+    public_dataset_token = str(uuid.UUID(int=3))
+    dataset_name = "my_dataset"
+
+    with mock_langsmith_client() as mock_client:
+        clone_dataset(public_dataset_token, dataset_name)
+        assert mock_client.datasets[0].name == dataset_name
+        assert len(mock_client.examples) == 2
+
+        # Check idempotency
+        clone_dataset(public_dataset_token, dataset_name)
+        assert len(mock_client.examples) == 2
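
For reference, a minimal usage sketch of the new clone_dataset helper added in this diff. It assumes a LangSmith API key is already configured in the environment for langsmith.Client(); the share token and dataset name below are placeholders, not real values.

# Minimal usage sketch for clone_dataset; token and dataset name are placeholders.
from langchain_benchmarks import clone_dataset

# Share token of a publicly shared LangSmith dataset (placeholder UUID).
public_token = "00000000-0000-0000-0000-000000000000"

# Copies the shared examples into a dataset in your own tenant.
clone_dataset(public_dataset_token=public_token, dataset_name="my-benchmark-dataset")

# Calling it again is a no-op: the helper returns early if a dataset with the
# same name already exists, as exercised by test_clone_dataset above.
clone_dataset(public_dataset_token=public_token, dataset_name="my-benchmark-dataset")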