diff --git a/langchain_benchmarks/__init__.py b/langchain_benchmarks/__init__.py index 7b42085f..0bdaa0ac 100644 --- a/langchain_benchmarks/__init__.py +++ b/langchain_benchmarks/__init__.py @@ -1,4 +1,7 @@ -from .utils._langsmith import clone_dataset +from langchain_benchmarks.utils._langsmith import ( + clone_public_dataset, + download_public_dataset, +) # Please keep this list sorted! -__all__ = ["clone_dataset"] +__all__ = ["clone_public_dataset", "download_public_dataset"] diff --git a/langchain_benchmarks/utils/_langsmith.py b/langchain_benchmarks/utils/_langsmith.py index 7645a23f..53eb0014 100644 --- a/langchain_benchmarks/utils/_langsmith.py +++ b/langchain_benchmarks/utils/_langsmith.py @@ -1,14 +1,41 @@ """Copy the public dataset to your own langsmith tenant.""" +import json +import urllib.parse +from pathlib import Path +from typing import Union, Optional, Tuple +from uuid import UUID + from langsmith import Client from langsmith.utils import LangSmithNotFoundError -from tqdm import tqdm +from tqdm import auto + +WEB_API_URL = "https://web.smith.langchain.com/" + + +def _parse_token_or_url(url_or_token: str, api_url: str) -> Tuple[str, Optional[str]]: + """Parse a public dataset URL or share token.""" + try: + UUID(url_or_token) + return api_url, url_or_token + except ValueError: + pass + + # Then it's a URL + parsed_url = urllib.parse.urlparse(url_or_token) + # Extract the UUID from the path + path_parts = parsed_url.path.split("/") + uuid = path_parts[-2] if len(path_parts) >= 2 else None + return WEB_API_URL, uuid + # PUBLIC API -def clone_dataset( - public_dataset_token: str, - dataset_name: str, +def clone_public_dataset( + token_or_url: str, + *, + dataset_name: Optional[str] = None, + source_api_url: str = WEB_API_URL, ) -> None: """Clone a public dataset to your own langsmith tenant. @@ -16,39 +43,80 @@ def clone_dataset( this function will do nothing. Args: - public_dataset_token (str): The token of the public dataset to clone. 
+ token_or_url (str): The token of the public dataset to clone. dataset_name (str): The name of the dataset to create in your tenant. + source_api_url: The URL of the langsmith server where the data is hosted """ - client = Client() - + if dataset_name is None: + raise NotImplementedError( + "Automatic dataset name generation is not implemented yet" + ) + client = Client() # Client used to write to langsmith try: - client.read_dataset(dataset_name=dataset_name) + dataset = client.read_dataset(dataset_name=dataset_name) + + if dataset: + print(f"Dataset {dataset_name} already exists. Skipping.") + print(f"You can access the dataset at {dataset.url}.") + return except LangSmithNotFoundError: pass - else: - print(f"Dataset {dataset_name} already exists. Skipping.") - return - - # Fetch examples first - examples = tqdm(list(client.list_shared_examples(public_dataset_token))) - print("Finished fetching examples. Creating dataset...") - dataset = client.create_dataset(dataset_name=dataset_name) + + source_api_url, uuid = _parse_token_or_url(token_or_url, source_api_url) + source_client = Client(api_url=source_api_url, api_key="placeholder") try: - client.create_examples( - inputs=[e.inputs for e in examples], - outputs=[e.outputs for e in examples], - dataset_id=dataset.id, - ) - except BaseException as e: - # Let's not do automatic clean up for now in case there might be - # some other reasons why create_examples fails (i.e., not network issue or - # keyboard interrupt). - # The risk is that this is an existing dataset that has valid examples - # populated from another source so we don't want to delete it. - print( - f"An error occurred while creating dataset {dataset_name}. " - "You should delete it manually." - ) - raise e + # Fetch examples first + examples = auto.tqdm(list(source_client.list_shared_examples(uuid))) + print("Finished fetching examples. 
Creating dataset...") + dataset = client.create_dataset(dataset_name=dataset_name) + print(f"New dataset created. You can access it at {dataset.url}.") + try: + client.create_examples( + inputs=[e.inputs for e in examples], + outputs=[e.outputs for e in examples], + dataset_id=dataset.id, + ) + except BaseException as e: + # Let's not do automatic clean up for now in case there might be + # some other reasons why create_examples fails (i.e., not network issue or + # keyboard interrupt). + # The risk is that this is an existing dataset that has valid examples + # populated from another source so we don't want to delete it. + print( + f"An error occurred while creating dataset {dataset_name}. " + "You should delete it manually." + ) + raise e + + print("Done creating dataset.") + finally: + del source_client + del client - print("Done creating dataset.") + +def download_public_dataset( + token_or_url: str, + *, + path: Optional[Union[str, Path]] = None, + api_url: str = WEB_API_URL, +) -> None: + """Download a public dataset.""" + api_url, uuid = _parse_token_or_url(token_or_url, api_url) + _path = str(path) if path else f"{uuid}.json" + if not _path.endswith(".json"): + raise ValueError(f"Path must end with .json got: {_path}") + + # This is the client where the source data lives + # The destination for the dataset is the local filesystem + source_client = Client(api_url=api_url, api_key="placeholder") + + try: + # Fetch examples first + print("Fetching examples...") + examples = auto.tqdm(list(source_client.list_shared_examples(uuid))) + with open(str(_path), mode="w", encoding="utf-8") as f: + jsonifable_examples = [json.loads(example.json()) for example in examples] + json.dump(jsonifable_examples, f, indent=2) + print("Done fetching examples.") + finally: + del source_client diff --git a/poetry.lock b/poetry.lock index 7cc06f82..854bfc70 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1071,6 +1071,27 @@ files = [ {file = "ipython_genutils-0.2.0.tar.gz", hash = 
"sha256:eb2e116e75ecef9d4d228fdc66af54269afa26ab4463042e33785b887c628ba8"}, ] +[[package]] +name = "ipywidgets" +version = "8.1.1" +description = "Jupyter interactive widgets" +optional = false +python-versions = ">=3.7" +files = [ + {file = "ipywidgets-8.1.1-py3-none-any.whl", hash = "sha256:2b88d728656aea3bbfd05d32c747cfd0078f9d7e159cf982433b58ad717eed7f"}, + {file = "ipywidgets-8.1.1.tar.gz", hash = "sha256:40211efb556adec6fa450ccc2a77d59ca44a060f4f9f136833df59c9f538e6e8"}, +] + +[package.dependencies] +comm = ">=0.1.3" +ipython = ">=6.1.0" +jupyterlab-widgets = ">=3.0.9,<3.1.0" +traitlets = ">=4.3.1" +widgetsnbextension = ">=4.0.9,<4.1.0" + +[package.extras] +test = ["ipykernel", "jsonschema", "pytest (>=3.6.0)", "pytest-cov", "pytz"] + [[package]] name = "isoduration" version = "20.11.0" @@ -1452,6 +1473,17 @@ docs = ["autodoc-traits", "jinja2 (<3.2.0)", "mistune (<4)", "myst-parser", "pyd openapi = ["openapi-core (>=0.18.0,<0.19.0)", "ruamel-yaml"] test = ["hatch", "ipykernel", "openapi-core (>=0.18.0,<0.19.0)", "openapi-spec-validator (>=0.6.0,<0.8.0)", "pytest (>=7.0)", "pytest-console-scripts", "pytest-cov", "pytest-jupyter[server] (>=0.6.2)", "pytest-timeout", "requests-mock", "ruamel-yaml", "sphinxcontrib-spelling", "strict-rfc3339", "werkzeug"] +[[package]] +name = "jupyterlab-widgets" +version = "3.0.9" +description = "Jupyter interactive widgets for JupyterLab" +optional = false +python-versions = ">=3.7" +files = [ + {file = "jupyterlab_widgets-3.0.9-py3-none-any.whl", hash = "sha256:3cf5bdf5b897bf3bccf1c11873aa4afd776d7430200f765e0686bd352487b58d"}, + {file = "jupyterlab_widgets-3.0.9.tar.gz", hash = "sha256:6005a4e974c7beee84060fdfba341a3218495046de8ae3ec64888e5fe19fdb4c"}, +] + [[package]] name = "langchain" version = "0.0.336" @@ -3327,6 +3359,17 @@ docs = ["Sphinx (>=6.0)", "sphinx-rtd-theme (>=1.1.0)"] optional = ["python-socks", "wsaccel"] test = ["websockets"] +[[package]] +name = "widgetsnbextension" +version = "4.0.9" +description = 
"Jupyter interactive widgets for Jupyter Notebook" +optional = false +python-versions = ">=3.7" +files = [ + {file = "widgetsnbextension-4.0.9-py3-none-any.whl", hash = "sha256:91452ca8445beb805792f206e560c1769284267a30ceb1cec9f5bcc887d15175"}, + {file = "widgetsnbextension-4.0.9.tar.gz", hash = "sha256:3c1f5e46dc1166dfd40a42d685e6a51396fd34ff878742a3e47c6f0cc4a2a385"}, +] + [[package]] name = "y-py" version = "0.6.2" @@ -3534,4 +3577,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.8.1" -content-hash = "cce4d8c816ccfa58210fedc8a161a960cd4554777ed30a5da9e236be5c7c87db" +content-hash = "83e140fae605ab8da7d9259b93fb9648ecabd73a20f766647793b4dcc6287d37" diff --git a/pyproject.toml b/pyproject.toml index 174f0463..13c3abac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain-benchmarks" -version = "0.1.0" +version = "0.0.1" description = "Flex them feathers! 🦜💪" authors = ["LangChain AI"] license = "MIT" @@ -11,6 +11,7 @@ python = "^3.8.1" langchain = ">=0.0.333" langsmith = "^0.0.64" tqdm = "^4.66.1" +ipywidgets = "^8.1.1" [tool.poetry.group.dev.dependencies] jupyterlab = "^3.6.1" diff --git a/tests/unit_tests/test_public_api.py b/tests/unit_tests/test_public_api.py index c549a48a..31251f14 100644 --- a/tests/unit_tests/test_public_api.py +++ b/tests/unit_tests/test_public_api.py @@ -5,8 +5,4 @@ def test_public_api() -> None: """Test that the public API is correct.""" # This test will also fail if __all__ is not sorted. # Please keep it sorted! 
- assert __all__ == sorted( - [ - "clone_dataset", - ] - ) + assert __all__ == sorted(["clone_public_dataset", "download_public_dataset"]) diff --git a/tests/unit_tests/test_utils.py b/tests/unit_tests/test_utils.py index 05fd6de6..e08aea60 100644 --- a/tests/unit_tests/test_utils.py +++ b/tests/unit_tests/test_utils.py @@ -9,7 +9,7 @@ from langsmith.schemas import Dataset, Example from langsmith.utils import LangSmithNotFoundError -from langchain_benchmarks.utils._langsmith import clone_dataset +from langchain_benchmarks.utils._langsmith import clone_public_dataset # Define a mock Client class that overrides the required methods @@ -95,10 +95,10 @@ def test_clone_dataset() -> None: dataset_name = "my_dataset" with mock_langsmith_client() as mock_client: - clone_dataset(public_dataset_token, dataset_name) + clone_public_dataset(public_dataset_token, dataset_name=dataset_name) assert mock_client.datasets[0].name == dataset_name assert len(mock_client.examples) == 2 # Check idempotency - clone_dataset(public_dataset_token, dataset_name) + clone_public_dataset(public_dataset_token, dataset_name=dataset_name) assert len(mock_client.examples) == 2