diff --git a/lilac/data/clustering.py b/lilac/data/clustering.py
index a9762372f..7530aafe4 100644
--- a/lilac/data/clustering.py
+++ b/lilac/data/clustering.py
@@ -104,17 +104,7 @@ def cluster_impl(
   if output_path:
     cluster_output_path = normalize_path(output_path)
   elif path:
-    # The sibling output path is the same as the input path, but with a different suffix.
-    index = 0
-    for i, path_part in enumerate(path):
-      if path_part == PATH_WILDCARD:
-        break
-      else:
-        index = i
-
-    parent = path[:index]
-    sibling = '_'.join([p for p in path[index:] if p != PATH_WILDCARD])
-    cluster_output_path = (*parent, f'{sibling}__{FIELD_SUFFIX}')
+    cluster_output_path = default_cluster_output_path(path)
   else:
     raise ValueError('input must be provided.')
 
@@ -416,3 +406,19 @@ def _hdbscan_cluster(
     for cluster_id, membership_prob in zip(labels, memberships):
       yield {CLUSTER_ID: int(cluster_id), CLUSTER_MEMBERSHIP_PROB: float(membership_prob)}
+
+
+def default_cluster_output_path(input_path: Path) -> PathTuple:
+  """Default output path for clustering."""
+  input_path = normalize_path(input_path)
+  # The sibling output path is the same as the input path, but with a different suffix.
+  index = 0
+  for i, path_part in enumerate(input_path):
+    if path_part == PATH_WILDCARD:
+      break
+    else:
+      index = i
+
+  parent = input_path[:index]
+  sibling = '_'.join([p for p in input_path[index:] if p != PATH_WILDCARD])
+  return (*parent, f'{sibling}__{FIELD_SUFFIX}')
diff --git a/lilac/formats/openai_json.py b/lilac/formats/openai_json.py
index 3b861daec..47a198047 100644
--- a/lilac/formats/openai_json.py
+++ b/lilac/formats/openai_json.py
@@ -32,7 +32,7 @@ class OpenAIJSON(DatasetFormat):
   Taken from: https://platform.openai.com/docs/api-reference/chat
   """
 
-  name: ClassVar[str] = 'openai_json'
+  name: ClassVar[str] = 'OpenAI JSON'
   data_schema: Schema = schema(
     {
       'messages': [
@@ -88,7 +88,7 @@ class OpenAIConversationJSON(DatasetFormat):
   Note that here "messages" is "conversation" for support with common datasets.
""" - name: ClassVar[str] = 'openai_conversation_json' + name: ClassVar[str] = 'OpenAI Conversation JSON' data_schema: Schema = schema( { 'conversation': [ diff --git a/lilac/formats/openchat.py b/lilac/formats/openchat.py index 815268e0d..9bee2ee32 100644 --- a/lilac/formats/openchat.py +++ b/lilac/formats/openchat.py @@ -10,7 +10,7 @@ class OpenChat(DatasetFormat): """OpenChat format.""" - name: ClassVar[str] = 'openchat' + name: ClassVar[str] = 'OpenChat' data_schema: Schema = schema( { 'items': [ diff --git a/lilac/formats/sharegpt.py b/lilac/formats/sharegpt.py index 75c9f2e91..30134205d 100644 --- a/lilac/formats/sharegpt.py +++ b/lilac/formats/sharegpt.py @@ -37,7 +37,7 @@ def _sharegpt_selector(item: Item, conv_from: str) -> str: class ShareGPT(DatasetFormat): """ShareGPT format.""" - name: ClassVar[str] = 'sharegpt' + name: ClassVar[str] = 'ShareGPT' data_schema: Schema = schema( { 'conversations': [ @@ -59,5 +59,5 @@ class ShareGPT(DatasetFormat): input_selectors: ClassVar[dict[str, DatasetFormatInputSelector]] = { selector.name: selector - for selector in [_SYSTEM_SELECTOR, _HUMAN_SELECTOR, _GPT_SELECTOR, _TOOL_SELECTOR] + for selector in [_HUMAN_SELECTOR, _SYSTEM_SELECTOR, _GPT_SELECTOR, _TOOL_SELECTOR] } diff --git a/lilac/load_test.py b/lilac/load_test.py index 0c3ae2895..1ac93e244 100644 --- a/lilac/load_test.py +++ b/lilac/load_test.py @@ -513,7 +513,7 @@ def _test_topic_fn(docs: list[tuple[str, float]]) -> str: dataset_namespace='namespace', dataset_name='test', input_selector=ClusterInputSelectorConfig( - format='sharegpt', + format='ShareGPT', selector='human', ), output_path=('cluster',), diff --git a/lilac/router_dataset.py b/lilac/router_dataset.py index fbef6181f..76a0b08e2 100644 --- a/lilac/router_dataset.py +++ b/lilac/router_dataset.py @@ -534,3 +534,13 @@ def restore_rows( searches=options.searches, filters=sanitized_filters, ) + + +@router.get('/{namespace}/{dataset_name}/format_selectors') +def get_format_selectors(namespace: str, dataset_name: str) -> list[str]: + """Get format selectors for the dataset if a format has been inferred.""" + dataset = get_dataset(namespace, dataset_name) + manifest = dataset.manifest() + if manifest.dataset_format: + return list(manifest.dataset_format.input_selectors.keys()) + return [] diff --git a/lilac/router_dataset_signals.py b/lilac/router_dataset_signals.py index 7c90335cc..4c4f4d8df 100644 --- a/lilac/router_dataset_signals.py +++ b/lilac/router_dataset_signals.py @@ -1,5 +1,5 @@ """Routing endpoints for running signals on datasets.""" -from typing import Annotated, Optional +from typing import Annotated, Optional, Union from fastapi import APIRouter, HTTPException from fastapi.params import Depends @@ -7,9 +7,11 @@ from pydantic import Field as PydanticField from .auth import UserInfo, get_session_user, get_user_access +from .data.clustering import default_cluster_output_path +from .dataset_format import DatasetFormatInputSelector, get_dataset_format_cls from .db_manager import get_dataset from .router_utils import RouteErrorHandler -from .schema import Path +from .schema import Path, PathTuple, normalize_path from .signal import Signal, resolve_signal from .tasks import TaskId, get_task_manager, launch_task @@ -82,7 +84,9 @@ def run() -> None: class ClusterOptions(BaseModel): """The request for the cluster endpoint.""" - input: Path + input: Optional[Path] = None + input_selector: Optional[str] = None + output_path: Optional[Path] = None use_garden: bool = PydanticField( default=False, description='Accelerate 
computation by running remotely on Lilac Garden.' @@ -107,14 +111,36 @@ def cluster( if not get_user_access(user).dataset.compute_signals: raise HTTPException(401, 'User does not have access to compute clusters over this dataset.') - path_str = '.'.join(map(str, options.input)) - task_name = f'[{namespace}/{dataset_name}] Clustering "{path_str}"' - task_id = get_task_manager().task_id(name=task_name) dataset = get_dataset(namespace, dataset_name) + manifest = dataset.manifest() + + cluster_input: Optional[Union[DatasetFormatInputSelector, PathTuple]] = None + if options.input: + path_str = '.'.join(map(str, options.input)) + task_name = f'[{namespace}/{dataset_name}] Clustering "{path_str}"' + cluster_input = normalize_path(options.input) + elif options.input_selector: + dataset_format = manifest.dataset_format + if dataset_format is None: + raise ValueError('Dataset format is not defined.') + + format_cls = get_dataset_format_cls(dataset_format.name) + if format_cls is None: + raise ValueError(f'Unknown format: {dataset_format.name}') + + cluster_input = format_cls.input_selectors[options.input_selector] + + task_name = ( + f'[{namespace}/{dataset_name}] Clustering using input selector ' f'"{options.input_selector}"' + ) + else: + raise HTTPException(400, 'Either input or input_selector must be provided.') + + task_id = get_task_manager().task_id(name=task_name) def run() -> None: dataset.cluster( - options.input, + cluster_input, options.output_path, use_garden=options.use_garden, overwrite=options.overwrite, @@ -125,6 +151,18 @@ def run() -> None: return ClusterResponse(task_id=task_id) +class DefaultClusterOutputPathOptions(BaseModel): + """Request body for the default cluster output path endpoint.""" + + input_path: Path + + +@router.post('/{namespace}/{dataset_name}/default_cluster_output_path') +def get_default_cluster_output_path(options: DefaultClusterOutputPathOptions) -> Path: + """Get format selectors for the dataset if a format has been inferred.""" + return default_cluster_output_path(options.input_path) + + class DeleteSignalOptions(BaseModel): """The request for the delete signal endpoint.""" diff --git a/web/blueprint/src/lib/components/ComputeClusterModal.svelte b/web/blueprint/src/lib/components/ComputeClusterModal.svelte index 1763643c8..1380a7aff 100644 --- a/web/blueprint/src/lib/components/ComputeClusterModal.svelte +++ b/web/blueprint/src/lib/components/ComputeClusterModal.svelte @@ -18,14 +18,21 @@