From b91e496359fef5d49512536b8d227683571111b4 Mon Sep 17 00:00:00 2001
From: Nikhil Thorat
Date: Thu, 15 Feb 2024 11:15:24 -0500
Subject: [PATCH 1/7] formats

---
 lilac/formats/sharegpt.py                     |   2 +-
 lilac/router_dataset.py                       |  10 ++
 lilac/router_dataset_signals.py               |  49 ++++++-
 notebooks/Clustering copy.ipynb               | 120 ++++++++++++++++++
 .../lib/components/ComputeClusterModal.svelte | 119 ++++++++++++-----
 .../src/lib/queries/datasetQueries.ts         |   4 +
 web/lib/fastapi_client/index.ts               |   1 +
 .../models/ClusterInputSelectorConfig.ts      |  13 ++
 .../fastapi_client/models/ClusterOptions.ts   |   5 +-
 .../services/DatasetsService.ts               |  25 ++++
 10 files changed, 304 insertions(+), 44 deletions(-)
 create mode 100644 notebooks/Clustering copy.ipynb
 create mode 100644 web/lib/fastapi_client/models/ClusterInputSelectorConfig.ts

diff --git a/lilac/formats/sharegpt.py b/lilac/formats/sharegpt.py
index 75c9f2e91..49a6b1038 100644
--- a/lilac/formats/sharegpt.py
+++ b/lilac/formats/sharegpt.py
@@ -59,5 +59,5 @@ class ShareGPT(DatasetFormat):
 
   input_selectors: ClassVar[dict[str, DatasetFormatInputSelector]] = {
     selector.name: selector
-    for selector in [_SYSTEM_SELECTOR, _HUMAN_SELECTOR, _GPT_SELECTOR, _TOOL_SELECTOR]
+    for selector in [_HUMAN_SELECTOR, _SYSTEM_SELECTOR, _GPT_SELECTOR, _TOOL_SELECTOR]
   }
diff --git a/lilac/router_dataset.py b/lilac/router_dataset.py
index fbef6181f..76a0b08e2 100644
--- a/lilac/router_dataset.py
+++ b/lilac/router_dataset.py
@@ -534,3 +534,13 @@ def restore_rows(
     searches=options.searches,
     filters=sanitized_filters,
   )
+
+
+@router.get('/{namespace}/{dataset_name}/format_selectors')
+def get_format_selectors(namespace: str, dataset_name: str) -> list[str]:
+  """Get format selectors for the dataset if a format has been inferred."""
+  dataset = get_dataset(namespace, dataset_name)
+  manifest = dataset.manifest()
+  if manifest.dataset_format:
+    return list(manifest.dataset_format.input_selectors.keys())
+  return []
diff --git a/lilac/router_dataset_signals.py b/lilac/router_dataset_signals.py
index 7c90335cc..b20455a60 100644
--- a/lilac/router_dataset_signals.py
+++ b/lilac/router_dataset_signals.py
@@ -1,5 +1,5 @@
 """Routing endpoints for running signals on datasets."""
-from typing import Annotated, Optional
+from typing import Annotated, Optional, Union
 
 from fastapi import APIRouter, HTTPException
 from fastapi.params import Depends
@@ -7,9 +7,11 @@
 from pydantic import Field as PydanticField
 
 from .auth import UserInfo, get_session_user, get_user_access
+from .config import ClusterInputSelectorConfig
+from .dataset_format import DatasetFormatInputSelector, get_dataset_format_cls
 from .db_manager import get_dataset
 from .router_utils import RouteErrorHandler
-from .schema import Path
+from .schema import Path, PathTuple, normalize_path
 from .signal import Signal, resolve_signal
 from .tasks import TaskId, get_task_manager, launch_task
 
@@ -82,7 +84,9 @@ def run() -> None:
 class ClusterOptions(BaseModel):
   """The request for the cluster endpoint."""
 
-  input: Path
+  input: Optional[Path] = None
+  input_selector: Optional[ClusterInputSelectorConfig] = None
+  output_path: Optional[Path] = None
 
   use_garden: bool = PydanticField(
     default=False, description='Accelerate computation by running remotely on Lilac Garden.'
@@ -107,14 +111,45 @@ def cluster(
   if not get_user_access(user).dataset.compute_signals:
     raise HTTPException(401, 'User does not have access to compute clusters over this dataset.')
 
-  path_str = '.'.join(map(str, options.input))
-  task_name = f'[{namespace}/{dataset_name}] Clustering "{path_str}"'
-  task_id = get_task_manager().task_id(name=task_name)
+  if options.input is None and options.input_selector is None:
+    raise HTTPException(400, 'Either input or input_selector must be provided.')
+
   dataset = get_dataset(namespace, dataset_name)
+  manifest = dataset.manifest()
+
+  cluster_input: Optional[Union[DatasetFormatInputSelector, PathTuple]] = None
+  if options.input:
+    path_str = '.'.join(map(str, options.input))
+    task_name = f'[{namespace}/{dataset_name}] Clustering "{path_str}"'
+    cluster_input = normalize_path(options.input)
+  elif options.input_selector:
+    dataset_format = manifest.dataset_format
+    if dataset_format is None:
+      raise ValueError('Dataset format is not defined.')
+
+    format_cls = get_dataset_format_cls(dataset_format.name)
+    if format_cls is None:
+      raise ValueError(f'Unknown format: {dataset_format.name}')
+
+    format = format_cls()
+    if format != manifest.dataset_format:
+      raise ValueError(
+        f'Cluster input format {options.input_selector.format} does not match '
+        f'dataset format {manifest.dataset_format}'
+      )
+
+    cluster_input = format_cls.input_selectors[options.input_selector.selector]
+
+    task_name = (
+      f'[{namespace}/{dataset_name}] Clustering using input selector '
+      f'"{options.input_selector.selector}"'
+    )
+
+  task_id = get_task_manager().task_id(name=task_name)
 
   def run() -> None:
     dataset.cluster(
-      options.input,
+      cluster_input,
       options.output_path,
       use_garden=options.use_garden,
       overwrite=options.overwrite,
diff --git a/notebooks/Clustering copy.ipynb b/notebooks/Clustering copy.ipynb
new file mode 100644
index 000000000..84a3ffbca
--- /dev/null
+++ b/notebooks/Clustering copy.ipynb
@@ -0,0 +1,120 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Clustering\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "This notebook accompanies the [Cluster a dataset](https://docs.lilacml.com/datasets/dataset_cluster.html) guide.\n",
+    "Let's start by loading a small dataset of multi-turn conversations between a human and a chatbot:\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Dataset \"capybara\" written to ./datasets/local/capybara\n"
+     ]
+    }
+   ],
+   "source": [
+    "import lilac as ll\n",
+    "\n",
+    "ds = ll.get_dataset('local', 'OpenHermes-2.5-100k')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We can cluster the `input` field under the `conversation` array by calling:\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "[local/capybara][1 shards] map \"extract_text\" to \"('conversation_input__cluster',)\": 100%|██████████| 16006/16006 [00:00<00:00, 30424.61it/s]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Wrote map output to conversation_input__cluster-00000-of-00001.parquet\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "[local/capybara][1 shards] map \"compute_clusters\" to \"('conversation_input__cluster',)\": 0%|          | 0/16006 [00:00
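
Taken together, the server changes above let a client start clustering by naming a
dataset-format input selector (for example the "human" side of ShareGPT
conversations) instead of an explicit input path. Below is a minimal client
sketch of that flow. It is not part of the patch: the '/api/v1/datasets' URL
prefix, the POST '/cluster' route, and the local port are assumptions based on
how the other dataset routes appear to be mounted, and the
{'format': ..., 'selector': ...} payload shape is inferred from the
ClusterInputSelectorConfig usage in the handler above.

# Hypothetical client sketch for the two endpoints touched by this patch.
# Assumptions (not confirmed by the patch): dataset routes mounted under
# /api/v1/datasets, clustering exposed as POST .../cluster, Lilac server
# running locally on port 5432.
import requests

BASE = 'http://localhost:5432/api/v1/datasets'
namespace, name = 'local', 'OpenHermes-2.5-100k'

# 1) Ask which input selectors the inferred dataset format exposes. The new
#    GET /format_selectors endpoint returns [] when no format was inferred.
selectors = requests.get(f'{BASE}/{namespace}/{name}/format_selectors').json()
print(selectors)  # For ShareGPT data: ['human', 'system', 'gpt', 'tool']

# 2) Launch clustering over a selector instead of an explicit input path.
#    The handler requires at least one of `input` or `input_selector`;
#    when both are given, `input` takes precedence.
response = requests.post(
  f'{BASE}/{namespace}/{name}/cluster',
  json={'input_selector': {'format': 'sharegpt', 'selector': 'human'}},
)
response.raise_for_status()  # The handler launches a background task.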