Skip to content

Commit

Permalink
Show file tree
Hide file tree
Showing 14 changed files with 270 additions and 62 deletions.
28 changes: 17 additions & 11 deletions lilac/data/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,7 @@ def cluster_impl(
if output_path:
cluster_output_path = normalize_path(output_path)
elif path:
# The sibling output path is the same as the input path, but with a different suffix.
index = 0
for i, path_part in enumerate(path):
if path_part == PATH_WILDCARD:
break
else:
index = i

parent = path[:index]
sibling = '_'.join([p for p in path[index:] if p != PATH_WILDCARD])
cluster_output_path = (*parent, f'{sibling}__{FIELD_SUFFIX}')
cluster_output_path = default_cluster_output_path(path)
else:
raise ValueError('input must be provided.')

Expand Down Expand Up @@ -416,3 +406,19 @@ def _hdbscan_cluster(

for cluster_id, membership_prob in zip(labels, memberships):
yield {CLUSTER_ID: int(cluster_id), CLUSTER_MEMBERSHIP_PROB: float(membership_prob)}


def default_cluster_output_path(input_path: Path) -> PathTuple:
"""Default output path for clustering."""
input_path = normalize_path(input_path)
# The sibling output path is the same as the input path, but with a different suffix.
index = 0
for i, path_part in enumerate(input_path):
if path_part == PATH_WILDCARD:
break
else:
index = i

parent = input_path[:index]
sibling = '_'.join([p for p in input_path[index:] if p != PATH_WILDCARD])
return (*parent, f'{sibling}__{FIELD_SUFFIX}')
4 changes: 2 additions & 2 deletions lilac/formats/openai_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class OpenAIJSON(DatasetFormat):
Taken from: https://platform.openai.com/docs/api-reference/chat
"""

name: ClassVar[str] = 'openai_json'
name: ClassVar[str] = 'OpenAI JSON'
data_schema: Schema = schema(
{
'messages': [
Expand Down Expand Up @@ -88,7 +88,7 @@ class OpenAIConversationJSON(DatasetFormat):
Note that here "messages" is "conversation" for support with common datasets.
"""

name: ClassVar[str] = 'openai_conversation_json'
name: ClassVar[str] = 'OpenAI Conversation JSON'
data_schema: Schema = schema(
{
'conversation': [
Expand Down
2 changes: 1 addition & 1 deletion lilac/formats/openchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class OpenChat(DatasetFormat):
"""OpenChat format."""

name: ClassVar[str] = 'openchat'
name: ClassVar[str] = 'OpenChat'
data_schema: Schema = schema(
{
'items': [
Expand Down
4 changes: 2 additions & 2 deletions lilac/formats/sharegpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _sharegpt_selector(item: Item, conv_from: str) -> str:
class ShareGPT(DatasetFormat):
"""ShareGPT format."""

name: ClassVar[str] = 'sharegpt'
name: ClassVar[str] = 'ShareGPT'
data_schema: Schema = schema(
{
'conversations': [
Expand All @@ -59,5 +59,5 @@ class ShareGPT(DatasetFormat):

input_selectors: ClassVar[dict[str, DatasetFormatInputSelector]] = {
selector.name: selector
for selector in [_SYSTEM_SELECTOR, _HUMAN_SELECTOR, _GPT_SELECTOR, _TOOL_SELECTOR]
for selector in [_HUMAN_SELECTOR, _SYSTEM_SELECTOR, _GPT_SELECTOR, _TOOL_SELECTOR]
}
2 changes: 1 addition & 1 deletion lilac/load_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ def _test_topic_fn(docs: list[tuple[str, float]]) -> str:
dataset_namespace='namespace',
dataset_name='test',
input_selector=ClusterInputSelectorConfig(
format='sharegpt',
format='ShareGPT',
selector='human',
),
output_path=('cluster',),
Expand Down
10 changes: 10 additions & 0 deletions lilac/router_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,3 +534,13 @@ def restore_rows(
searches=options.searches,
filters=sanitized_filters,
)


@router.get('/{namespace}/{dataset_name}/format_selectors')
def get_format_selectors(namespace: str, dataset_name: str) -> list[str]:
"""Get format selectors for the dataset if a format has been inferred."""
dataset = get_dataset(namespace, dataset_name)
manifest = dataset.manifest()
if manifest.dataset_format:
return list(manifest.dataset_format.input_selectors.keys())
return []
52 changes: 45 additions & 7 deletions lilac/router_dataset_signals.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
"""Routing endpoints for running signals on datasets."""
from typing import Annotated, Optional
from typing import Annotated, Optional, Union

from fastapi import APIRouter, HTTPException
from fastapi.params import Depends
from pydantic import BaseModel, SerializeAsAny, field_validator
from pydantic import Field as PydanticField

from .auth import UserInfo, get_session_user, get_user_access
from .data.clustering import default_cluster_output_path
from .dataset_format import DatasetFormatInputSelector, get_dataset_format_cls
from .db_manager import get_dataset
from .router_utils import RouteErrorHandler
from .schema import Path
from .schema import Path, PathTuple, normalize_path
from .signal import Signal, resolve_signal
from .tasks import TaskId, get_task_manager, launch_task

Expand Down Expand Up @@ -82,7 +84,9 @@ def run() -> None:
class ClusterOptions(BaseModel):
"""The request for the cluster endpoint."""

input: Path
input: Optional[Path] = None
input_selector: Optional[str] = None

output_path: Optional[Path] = None
use_garden: bool = PydanticField(
default=False, description='Accelerate computation by running remotely on Lilac Garden.'
Expand All @@ -107,14 +111,36 @@ def cluster(
if not get_user_access(user).dataset.compute_signals:
raise HTTPException(401, 'User does not have access to compute clusters over this dataset.')

path_str = '.'.join(map(str, options.input))
task_name = f'[{namespace}/{dataset_name}] Clustering "{path_str}"'
task_id = get_task_manager().task_id(name=task_name)
dataset = get_dataset(namespace, dataset_name)
manifest = dataset.manifest()

cluster_input: Optional[Union[DatasetFormatInputSelector, PathTuple]] = None
if options.input:
path_str = '.'.join(map(str, options.input))
task_name = f'[{namespace}/{dataset_name}] Clustering "{path_str}"'
cluster_input = normalize_path(options.input)
elif options.input_selector:
dataset_format = manifest.dataset_format
if dataset_format is None:
raise ValueError('Dataset format is not defined.')

format_cls = get_dataset_format_cls(dataset_format.name)
if format_cls is None:
raise ValueError(f'Unknown format: {dataset_format.name}')

cluster_input = format_cls.input_selectors[options.input_selector]

task_name = (
f'[{namespace}/{dataset_name}] Clustering using input selector ' f'"{options.input_selector}"'
)
else:
raise HTTPException(400, 'Either input or input_selector must be provided.')

task_id = get_task_manager().task_id(name=task_name)

def run() -> None:
dataset.cluster(
options.input,
cluster_input,
options.output_path,
use_garden=options.use_garden,
overwrite=options.overwrite,
Expand All @@ -125,6 +151,18 @@ def run() -> None:
return ClusterResponse(task_id=task_id)


class DefaultClusterOutputPathOptions(BaseModel):
"""Request body for the default cluster output path endpoint."""

input_path: Path


@router.post('/{namespace}/{dataset_name}/default_cluster_output_path')
def get_default_cluster_output_path(options: DefaultClusterOutputPathOptions) -> Path:
"""Get format selectors for the dataset if a format has been inferred."""
return default_cluster_output_path(options.input_path)


class DeleteSignalOptions(BaseModel):
"""The request for the delete signal endpoint."""

Expand Down
155 changes: 119 additions & 36 deletions web/blueprint/src/lib/components/ComputeClusterModal.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,21 @@
</script>

<script lang="ts">
import {clusterMutation} from '$lib/queries/datasetQueries';
import {
clusterMutation,
queryDatasetManifest,
queryDefaultClusterOutputPath,
queryFormatSelectors
} from '$lib/queries/datasetQueries';
import {queryAuthInfo} from '$lib/queries/serverQueries';
import type {Path} from '$lilac';
import {serializePath, type Path} from '$lilac';
import {
ComposedModal,
ModalBody,
ModalFooter,
ModalHeader,
Select,
SelectItem,
Toggle
} from 'carbon-components-svelte';
import FieldSelect from './commands/selectors/FieldSelect.svelte';
Expand All @@ -36,18 +43,63 @@
$: canComputeRemotely = $authInfo.data?.access.dataset.execute_remotely;
$: formatSelectorsQuery =
options != null ? queryFormatSelectors(options.namespace, options.datasetName) : null;
$: datasetManifest =
options != null ? queryDatasetManifest(options.namespace, options.datasetName) : null;
let selectedFormatSelector = 'none';
let formatSelectors: string[] | undefined = undefined;
let outputColumn: string | undefined = undefined;
$: outputColumnRequired =
formatSelectors != null &&
formatSelectors.length > 0 &&
selectedFormatSelector != null &&
selectedFormatSelector != 'none';
$: defaultClusterOutputPath = options?.input
? queryDefaultClusterOutputPath({input_path: options.input})
: null;
$: {
if ($defaultClusterOutputPath?.data != null) {
outputColumn = serializePath($defaultClusterOutputPath.data);
}
}
$: {
if (options?.output_path != null) {
outputColumn = serializePath(options.output_path);
}
}
$: {
if (
formatSelectorsQuery != null &&
$formatSelectorsQuery != null &&
$formatSelectorsQuery.data != null
) {
formatSelectors = $formatSelectorsQuery.data;
}
}
$: {
if (selectedFormatSelector != null && selectedFormatSelector != 'none') {
// Choose a reasonable default output column.
outputColumn = `${selectedFormatSelector}__clusters`;
} else if (selectedFormatSelector === 'none') {
outputColumn = undefined;
}
}
function close() {
store.set(null);
}
function submit() {
if (!options) return;
$clusterQuery.mutate([
options.namespace,
options.datasetName,
{
input: options.input,
input: selectedFormatSelector == null ? options.input : null,
use_garden: options.use_garden,
output_path: options.output_path,
output_path: outputColumn,
input_selector: selectedFormatSelector,
overwrite: options.overwrite
}
]);
Expand All @@ -59,47 +111,78 @@
<ComposedModal open on:submit={submit} on:close={close}>
<ModalHeader title="Compute clusters" />
<ModalBody hasForm>
<div class="max-w-2xl">
<FieldSelect
filter={f => f.dtype?.type === 'string'}
defaultPath={options.input}
bind:path={options.input}
labelText="Field"
/>
</div>
<div class="mt-8">
<div class="label mb-2 font-medium text-gray-700">Use Garden</div>
<div class="label mb-2 text-sm text-gray-700">
Accelerate computation by running remotely on <a
href="https://lilacml.com/#garden"
target="_blank">Lilac Garden</a
>
<div class="flex max-w-2xl flex-col gap-y-8">
<div>
<FieldSelect
disabled={selectedFormatSelector != null && selectedFormatSelector != 'none'}
filter={f => f.dtype?.type === 'string'}
defaultPath={options.input}
bind:path={options.input}
labelText="Field"
/>
</div>
<Toggle
disabled={!canComputeRemotely}
labelA={'False'}
labelB={'True'}
bind:toggled={options.use_garden}
hideLabel
/>
{#if !canComputeRemotely}
<div class="mt-2">
<a href="https://forms.gle/Gz9cpeKJccNar5Lq8" target="_blank">
Sign up for Lilac Garden
</a>
to enable this feature.
{#if formatSelectors != null && formatSelectors.length > 0}
<div>
<div class="label text-s mb-2 font-medium text-gray-700">
{$datasetManifest?.data?.dataset_manifest.dataset_format?.['format_name']} selector
</div>
<Select hideLabel={true} bind:selected={selectedFormatSelector} required>
<SelectItem value={'none'} text={'None'} />

{#each formatSelectors as formatSelector}
<SelectItem value={formatSelector} text={formatSelector} />
{/each}
</Select>
</div>
{/if}
</div>
<div class="mt-8">
<div class="label text-s mb-2 font-medium text-gray-700">Overwrite</div>
<Toggle labelA={'False'} labelB={'True'} bind:toggled={options.overwrite} hideLabel />
<div>
<div class="label text-s mb-2 font-medium text-gray-700">
{outputColumnRequired ? '*' : ''} Output column {!outputColumnRequired
? '(Optional)'
: ''}
</div>
<input
required={outputColumnRequired}
class="h-full w-full rounded border border-neutral-300 p-2"
placeholder="Choose a new column name to write clusters"
bind:value={outputColumn}
/>
</div>
<div>
<div class="label mb-2 font-medium text-gray-700">Use Garden</div>
<div class="label text-sm text-gray-700">
Accelerate computation by running remotely on <a
href="https://lilacml.com/#garden"
target="_blank">Lilac Garden</a
>
</div>
<Toggle
disabled={!canComputeRemotely}
labelA={'False'}
labelB={'True'}
bind:toggled={options.use_garden}
hideLabel
/>
{#if !canComputeRemotely}
<div>
<a href="https://forms.gle/Gz9cpeKJccNar5Lq8" target="_blank">
Sign up for Lilac Garden
</a>
to enable this feature.
</div>
{/if}
</div>
<div>
<div class="label text-s mb-2 font-medium text-gray-700">Overwrite</div>
<Toggle labelA={'False'} labelB={'True'} bind:toggled={options.overwrite} hideLabel />
</div>
</div>
</ModalBody>
<ModalFooter
primaryButtonText="Cluster"
secondaryButtonText="Cancel"
on:click:button--secondary={close}
primaryButtonDisabled={outputColumnRequired && !outputColumn}
/>
</ComposedModal>
{/if}
Loading

0 comments on commit 8e7418d

Please sign in to comment.