diff --git a/lilac/data/clustering.py b/lilac/data/clustering.py index a9762372..65eacece 100644 --- a/lilac/data/clustering.py +++ b/lilac/data/clustering.py @@ -104,17 +104,7 @@ def cluster_impl( if output_path: cluster_output_path = normalize_path(output_path) elif path: - # The sibling output path is the same as the input path, but with a different suffix. - index = 0 - for i, path_part in enumerate(path): - if path_part == PATH_WILDCARD: - break - else: - index = i - - parent = path[:index] - sibling = '_'.join([p for p in path[index:] if p != PATH_WILDCARD]) - cluster_output_path = (*parent, f'{sibling}__{FIELD_SUFFIX}') + cluster_output_path = default_cluster_output_path(path) else: raise ValueError('input must be provided.') @@ -416,3 +406,19 @@ def _hdbscan_cluster( for cluster_id, membership_prob in zip(labels, memberships): yield {CLUSTER_ID: int(cluster_id), CLUSTER_MEMBERSHIP_PROB: float(membership_prob)} + + +def default_cluster_output_path(input_path: Path) -> Path: + """Default output path for clustering.""" + input_path = normalize_path(input_path) + # The sibling output path is the same as the input path, but with a different suffix. 
+ index = 0 + for i, path_part in enumerate(input_path): + if path_part == PATH_WILDCARD: + break + else: + index = i + + parent = input_path[:index] + sibling = '_'.join([p for p in input_path[index:] if p != PATH_WILDCARD]) + return (*parent, f'{sibling}__{FIELD_SUFFIX}') diff --git a/lilac/router_dataset_signals.py b/lilac/router_dataset_signals.py index 70066875..4c4f4d8d 100644 --- a/lilac/router_dataset_signals.py +++ b/lilac/router_dataset_signals.py @@ -7,6 +7,7 @@ from pydantic import Field as PydanticField from .auth import UserInfo, get_session_user, get_user_access +from .data.clustering import default_cluster_output_path from .dataset_format import DatasetFormatInputSelector, get_dataset_format_cls from .db_manager import get_dataset from .router_utils import RouteErrorHandler @@ -150,6 +151,18 @@ def run() -> None: return ClusterResponse(task_id=task_id) +class DefaultClusterOutputPathOptions(BaseModel): + """Request body for the default cluster output path endpoint.""" + + input_path: Path + + +@router.post('/{namespace}/{dataset_name}/default_cluster_output_path') +def get_default_cluster_output_path(options: DefaultClusterOutputPathOptions) -> Path: + """Get the default cluster output path for the given input path.""" + return default_cluster_output_path(options.input_path) + + class DeleteSignalOptions(BaseModel): """The request for the delete signal endpoint.""" diff --git a/web/blueprint/src/lib/components/ComputeClusterModal.svelte b/web/blueprint/src/lib/components/ComputeClusterModal.svelte index 604ac1b6..1380a7af 100644 --- a/web/blueprint/src/lib/components/ComputeClusterModal.svelte +++ b/web/blueprint/src/lib/components/ComputeClusterModal.svelte @@ -21,6 +21,7 @@ import { clusterMutation, queryDatasetManifest, + queryDefaultClusterOutputPath, queryFormatSelectors } from '$lib/queries/datasetQueries'; import {queryAuthInfo} from '$lib/queries/serverQueries'; @@ -54,6 +55,14 @@ formatSelectors.length > 0 && selectedFormatSelector != 
null && selectedFormatSelector != 'none'; + $: defaultClusterOutputPath = options?.input + ? queryDefaultClusterOutputPath({input_path: options.input}) + : null; + $: { + if ($defaultClusterOutputPath?.data != null) { + outputColumn = serializePath($defaultClusterOutputPath.data); + } + } $: { if (options?.output_path != null) { outputColumn = serializePath(options.output_path); diff --git a/web/blueprint/src/lib/queries/datasetQueries.ts b/web/blueprint/src/lib/queries/datasetQueries.ts index 7b7373d5..608661e4 100644 --- a/web/blueprint/src/lib/queries/datasetQueries.ts +++ b/web/blueprint/src/lib/queries/datasetQueries.ts @@ -337,3 +337,8 @@ export const queryFormatSelectors = createApiQuery( DatasetsService.getFormatSelectors, DATASETS_TAG ); + +export const queryDefaultClusterOutputPath = createApiQuery( + DatasetsService.getDefaultClusterOutputPath, + DATASETS_TAG +); diff --git a/web/lib/fastapi_client/index.ts b/web/lib/fastapi_client/index.ts index b58bbe55..50ed213d 100644 --- a/web/lib/fastapi_client/index.ts +++ b/web/lib/fastapi_client/index.ts @@ -37,6 +37,7 @@ export type { DatasetSettings } from './models/DatasetSettings'; export type { DatasetUISettings } from './models/DatasetUISettings'; export type { DatasetUserAccess } from './models/DatasetUserAccess'; export type { DataType } from './models/DataType'; +export type { DefaultClusterOutputPathOptions } from './models/DefaultClusterOutputPathOptions'; export type { DeleteRowsOptions } from './models/DeleteRowsOptions'; export type { DeleteSignalOptions } from './models/DeleteSignalOptions'; export type { DeleteSignalResponse } from './models/DeleteSignalResponse'; diff --git a/web/lib/fastapi_client/models/DefaultClusterOutputPathOptions.ts b/web/lib/fastapi_client/models/DefaultClusterOutputPathOptions.ts new file mode 100644 index 00000000..459df2f3 --- /dev/null +++ b/web/lib/fastapi_client/models/DefaultClusterOutputPathOptions.ts @@ -0,0 +1,12 @@ +/* generated using 
openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +/** + * Request body for the default cluster output path endpoint. + */ +export type DefaultClusterOutputPathOptions = { + input_path: (Array | string); +}; + diff --git a/web/lib/fastapi_client/services/DatasetsService.ts b/web/lib/fastapi_client/services/DatasetsService.ts index c2b8c0c3..9bc66be4 100644 --- a/web/lib/fastapi_client/services/DatasetsService.ts +++ b/web/lib/fastapi_client/services/DatasetsService.ts @@ -9,6 +9,7 @@ import type { ComputeSignalOptions } from '../models/ComputeSignalOptions'; import type { ComputeSignalResponse } from '../models/ComputeSignalResponse'; import type { DatasetInfo } from '../models/DatasetInfo'; import type { DatasetSettings } from '../models/DatasetSettings'; +import type { DefaultClusterOutputPathOptions } from '../models/DefaultClusterOutputPathOptions'; import type { DeleteRowsOptions } from '../models/DeleteRowsOptions'; import type { DeleteSignalOptions } from '../models/DeleteSignalOptions'; import type { DeleteSignalResponse } from '../models/DeleteSignalResponse'; @@ -608,6 +609,27 @@ export class DatasetsService { }); } + /** + * Get Default Cluster Output Path + * Get format selectors for the dataset if a format has been inferred. + * @param requestBody + * @returns any Successful Response + * @throws ApiError + */ + public static getDefaultClusterOutputPath( + requestBody: DefaultClusterOutputPathOptions, + ): CancelablePromise<(Array | string)> { + return __request(OpenAPI, { + method: 'POST', + url: '/api/v1/datasets/{namespace}/{dataset_name}/default_cluster_output_path', + body: requestBody, + mediaType: 'application/json', + errors: { + 422: `Validation Error`, + }, + }); + } + /** * Delete Signal * Delete a signal from a dataset.