Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add format selectors to the compute clusters UI. #1185

Merged
merged 7 commits into from
Feb 17, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lilac/formats/openai_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class OpenAIJSON(DatasetFormat):
Taken from: https://platform.openai.com/docs/api-reference/chat
"""

name: ClassVar[str] = 'openai_json'
name: ClassVar[str] = 'OpenAI JSON'
data_schema: Schema = schema(
{
'messages': [
Expand Down Expand Up @@ -88,7 +88,7 @@ class OpenAIConversationJSON(DatasetFormat):
Note that here "messages" is "conversation" for support with common datasets.
"""

name: ClassVar[str] = 'openai_conversation_json'
name: ClassVar[str] = 'OpenAI Conversation JSON'
data_schema: Schema = schema(
{
'conversation': [
Expand Down
2 changes: 1 addition & 1 deletion lilac/formats/openchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class OpenChat(DatasetFormat):
"""OpenChat format."""

name: ClassVar[str] = 'openchat'
name: ClassVar[str] = 'OpenChat'
data_schema: Schema = schema(
{
'items': [
Expand Down
4 changes: 2 additions & 2 deletions lilac/formats/sharegpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _sharegpt_selector(item: Item, conv_from: str) -> str:
class ShareGPT(DatasetFormat):
"""ShareGPT format."""

name: ClassVar[str] = 'sharegpt'
name: ClassVar[str] = 'ShareGPT'
data_schema: Schema = schema(
{
'conversations': [
Expand All @@ -59,5 +59,5 @@ class ShareGPT(DatasetFormat):

input_selectors: ClassVar[dict[str, DatasetFormatInputSelector]] = {
selector.name: selector
for selector in [_SYSTEM_SELECTOR, _HUMAN_SELECTOR, _GPT_SELECTOR, _TOOL_SELECTOR]
for selector in [_HUMAN_SELECTOR, _SYSTEM_SELECTOR, _GPT_SELECTOR, _TOOL_SELECTOR]
}
2 changes: 1 addition & 1 deletion lilac/load_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ def _test_topic_fn(docs: list[tuple[str, float]]) -> str:
dataset_namespace='namespace',
dataset_name='test',
input_selector=ClusterInputSelectorConfig(
format='sharegpt',
format='ShareGPT',
selector='human',
),
output_path=('cluster',),
Expand Down
10 changes: 10 additions & 0 deletions lilac/router_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,3 +534,13 @@ def restore_rows(
searches=options.searches,
filters=sanitized_filters,
)


@router.get('/{namespace}/{dataset_name}/format_selectors')
def get_format_selectors(namespace: str, dataset_name: str) -> list[str]:
  """Get format selectors for the dataset if a format has been inferred."""
  # The docstring is intentionally left untouched: FastAPI exposes it as the endpoint description.
  inferred_format = get_dataset(namespace, dataset_name).manifest().dataset_format
  if not inferred_format:
    # The manifest carries no dataset format, so there are no selectors to report.
    return []
  # Iterating the selectors mapping yields its keys, i.e. the selector names.
  return list(inferred_format.input_selectors)
39 changes: 32 additions & 7 deletions lilac/router_dataset_signals.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
"""Routing endpoints for running signals on datasets."""
from typing import Annotated, Optional
from typing import Annotated, Optional, Union

from fastapi import APIRouter, HTTPException
from fastapi.params import Depends
from pydantic import BaseModel, SerializeAsAny, field_validator
from pydantic import Field as PydanticField

from .auth import UserInfo, get_session_user, get_user_access
from .dataset_format import DatasetFormatInputSelector, get_dataset_format_cls
from .db_manager import get_dataset
from .router_utils import RouteErrorHandler
from .schema import Path
from .schema import Path, PathTuple, normalize_path
from .signal import Signal, resolve_signal
from .tasks import TaskId, get_task_manager, launch_task

Expand Down Expand Up @@ -82,7 +83,9 @@ def run() -> None:
class ClusterOptions(BaseModel):
"""The request for the cluster endpoint."""

input: Path
input: Optional[Path] = None
input_selector: Optional[str] = None

output_path: Optional[Path] = None
use_garden: bool = PydanticField(
default=False, description='Accelerate computation by running remotely on Lilac Garden.'
Expand All @@ -107,14 +110,36 @@ def cluster(
if not get_user_access(user).dataset.compute_signals:
raise HTTPException(401, 'User does not have access to compute clusters over this dataset.')

path_str = '.'.join(map(str, options.input))
task_name = f'[{namespace}/{dataset_name}] Clustering "{path_str}"'
task_id = get_task_manager().task_id(name=task_name)
dataset = get_dataset(namespace, dataset_name)
manifest = dataset.manifest()

cluster_input: Optional[Union[DatasetFormatInputSelector, PathTuple]] = None
if options.input:
path_str = '.'.join(map(str, options.input))
task_name = f'[{namespace}/{dataset_name}] Clustering "{path_str}"'
cluster_input = normalize_path(options.input)
elif options.input_selector:
dataset_format = manifest.dataset_format
if dataset_format is None:
raise ValueError('Dataset format is not defined.')

format_cls = get_dataset_format_cls(dataset_format.name)
if format_cls is None:
raise ValueError(f'Unknown format: {dataset_format.name}')

cluster_input = format_cls.input_selectors[options.input_selector]

task_name = (
f'[{namespace}/{dataset_name}] Clustering using input selector ' f'"{options.input_selector}"'
)
else:
raise HTTPException(400, 'Either input or input_selector must be provided.')

task_id = get_task_manager().task_id(name=task_name)

def run() -> None:
dataset.cluster(
options.input,
cluster_input,
options.output_path,
use_garden=options.use_garden,
overwrite=options.overwrite,
Expand Down
138 changes: 102 additions & 36 deletions web/blueprint/src/lib/components/ComputeClusterModal.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,20 @@
</script>

<script lang="ts">
import {clusterMutation} from '$lib/queries/datasetQueries';
import {
clusterMutation,
queryDatasetManifest,
queryFormatSelectors
} from '$lib/queries/datasetQueries';
import {queryAuthInfo} from '$lib/queries/serverQueries';
import type {Path} from '$lilac';
import {serializePath, type Path} from '$lilac';
import {
ComposedModal,
ModalBody,
ModalFooter,
ModalHeader,
Select,
SelectItem,
Toggle
} from 'carbon-components-svelte';
import FieldSelect from './commands/selectors/FieldSelect.svelte';
Expand All @@ -36,18 +42,47 @@

$: canComputeRemotely = $authInfo.data?.access.dataset.execute_remotely;

$: formatSelectorsQuery =
options != null ? queryFormatSelectors(options.namespace, options.datasetName) : null;
$: datasetManifest =
options != null ? queryDatasetManifest(options.namespace, options.datasetName) : null;
let selectedFormatSelector = 'none';
let formatSelectors: string[] | undefined = undefined;
let outputColumn: string | undefined = undefined;
$: outputColumnRequired =
formatSelectors != null &&
formatSelectors.length > 0 &&
selectedFormatSelector != null &&
selectedFormatSelector != 'none';
$: {
if (options?.output_path != null) {
outputColumn = serializePath(options.output_path);
}
}
$: {
if (
formatSelectorsQuery != null &&
$formatSelectorsQuery != null &&
$formatSelectorsQuery.data != null
) {
formatSelectors = $formatSelectorsQuery.data;
}
}

function close() {
store.set(null);
}
function submit() {
if (!options) return;

$clusterQuery.mutate([
options.namespace,
options.datasetName,
{
input: options.input,
input: selectedFormatSelector == null ? options.input : null,
use_garden: options.use_garden,
output_path: options.output_path,
output_path: outputColumn,
input_selector: selectedFormatSelector,
overwrite: options.overwrite
}
]);
Expand All @@ -59,47 +94,78 @@
<ComposedModal open on:submit={submit} on:close={close}>
<ModalHeader title="Compute clusters" />
<ModalBody hasForm>
<div class="max-w-2xl">
<FieldSelect
filter={f => f.dtype?.type === 'string'}
defaultPath={options.input}
bind:path={options.input}
labelText="Field"
/>
</div>
<div class="mt-8">
<div class="label mb-2 font-medium text-gray-700">Use Garden</div>
<div class="label mb-2 text-sm text-gray-700">
Accelerate computation by running remotely on <a
href="https://lilacml.com/#garden"
target="_blank">Lilac Garden</a
>
<div class="flex max-w-2xl flex-col gap-y-8">
<div>
<FieldSelect
disabled={selectedFormatSelector != null && selectedFormatSelector != 'none'}
filter={f => f.dtype?.type === 'string'}
defaultPath={options.input}
bind:path={options.input}
labelText="Field"
/>
</div>
<Toggle
disabled={!canComputeRemotely}
labelA={'False'}
labelB={'True'}
bind:toggled={options.use_garden}
hideLabel
/>
{#if !canComputeRemotely}
<div class="mt-2">
<a href="https://forms.gle/Gz9cpeKJccNar5Lq8" target="_blank">
Sign up for Lilac Garden
</a>
to enable this feature.
{#if formatSelectors != null && formatSelectors.length > 0}
<div>
<div class="label text-s mb-2 font-medium text-gray-700">
{$datasetManifest?.data?.dataset_manifest.dataset_format?.['format_name']} selector
</div>
<Select hideLabel={true} bind:selected={selectedFormatSelector} required>
<SelectItem value={'none'} text={'None'} />

{#each formatSelectors as formatSelector}
<SelectItem value={formatSelector} text={formatSelector} />
{/each}
</Select>
</div>
{/if}
</div>
<div class="mt-8">
<div class="label text-s mb-2 font-medium text-gray-700">Overwrite</div>
<Toggle labelA={'False'} labelB={'True'} bind:toggled={options.overwrite} hideLabel />
<div>
<div class="label text-s mb-2 font-medium text-gray-700">
{outputColumnRequired ? '*' : ''} Output column {!outputColumnRequired
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the output column is optional, is it possible to populate the input box with the default column name? That would tell the user where the results will end up.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

? '(Optional)'
: ''}
</div>
<input
required={outputColumnRequired}
class="h-full w-full rounded border border-neutral-300 p-2"
placeholder="Choose a new column name to write clusters"
bind:value={outputColumn}
/>
</div>
<div>
<div class="label mb-2 font-medium text-gray-700">Use Garden</div>
<div class="label text-sm text-gray-700">
Accelerate computation by running remotely on <a
href="https://lilacml.com/#garden"
target="_blank">Lilac Garden</a
>
</div>
<Toggle
disabled={!canComputeRemotely}
labelA={'False'}
labelB={'True'}
bind:toggled={options.use_garden}
hideLabel
/>
{#if !canComputeRemotely}
<div>
<a href="https://forms.gle/Gz9cpeKJccNar5Lq8" target="_blank">
Sign up for Lilac Garden
</a>
to enable this feature.
</div>
{/if}
</div>
<div>
<div class="label text-s mb-2 font-medium text-gray-700">Overwrite</div>
<Toggle labelA={'False'} labelB={'True'} bind:toggled={options.overwrite} hideLabel />
</div>
</div>
</ModalBody>
<ModalFooter
primaryButtonText="Cluster"
secondaryButtonText="Cancel"
on:click:button--secondary={close}
primaryButtonDisabled={outputColumnRequired && !outputColumn}
/>
</ComposedModal>
{/if}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

export let defaultPath: Path | undefined = undefined;
export let path: Path | undefined = undefined;
export let disabled = false;

const datasetViewStore = getDatasetViewContext();

Expand Down Expand Up @@ -83,7 +84,7 @@
<div class="label text-s mb-2 font-medium text-gray-700">
{labelText}
</div>
<Select hideLabel={true} {helperText} bind:selected={selectedPath} required>
<Select hideLabel={true} {helperText} bind:selected={selectedPath} required {disabled}>
{#if sourceFields?.length}
<SelectItemGroup label="Source Fields">
{#each sourceFields as field}
Expand Down
4 changes: 4 additions & 0 deletions web/blueprint/src/lib/queries/datasetQueries.ts
Original file line number Diff line number Diff line change
Expand Up @@ -333,3 +333,7 @@ function invalidateQueriesLabelEdit(
]);
}
}
/**
 * Query for the format selector names of a dataset.
 *
 * Built by `createApiQuery` around `DatasetsService.getFormatSelectors` and registered under
 * `DATASETS_TAG`. It is called with `(namespace, datasetName)`, and its `data` is the list of
 * selector names returned by the `format_selectors` endpoint.
 * NOTE(review): the tag is presumably what ties this query to dataset cache invalidation --
 * confirm against `createApiQuery`.
 */
export const queryFormatSelectors = createApiQuery(
DatasetsService.getFormatSelectors,
DATASETS_TAG
);
3 changes: 2 additions & 1 deletion web/lib/fastapi_client/models/ClusterOptions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
* The request for the cluster endpoint.
*/
export type ClusterOptions = {
input: (Array<string> | string);
input?: (Array<string> | string | null);
input_selector?: (string | null);
output_path?: (Array<string> | string | null);
/**
* Accelerate computation by running remotely on Lilac Garden.
Expand Down
25 changes: 25 additions & 0 deletions web/lib/fastapi_client/services/DatasetsService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,31 @@ export class DatasetsService {
});
}

/**
* Get Format Selectors
* Get format selectors for the dataset if a format has been inferred.
*
* Issues a GET request to `/api/v1/datasets/{namespace}/{dataset_name}/format_selectors`.
* A 422 response is described as `Validation Error`.
* NOTE(review): this file looks like generated client code -- manual comment edits may be
* overwritten the next time the client is regenerated; confirm before relying on them.
* @param namespace Value substituted for `{namespace}` in the request URL.
* @param datasetName Value substituted for `{dataset_name}` in the request URL.
* @returns string Successful Response
* @throws ApiError
*/
public static getFormatSelectors(
namespace: string,
datasetName: string,
): CancelablePromise<Array<string>> {
return __request(OpenAPI, {
method: 'GET',
url: '/api/v1/datasets/{namespace}/{dataset_name}/format_selectors',
path: {
'namespace': namespace,
'dataset_name': datasetName,
},
errors: {
422: `Validation Error`,
},
});
}

/**
* Compute Signal
* Compute a signal for a dataset.
Expand Down
Loading