Skip to content

Commit

Permalink
save
Browse files Browse the repository at this point in the history
  • Loading branch information
nsthorat committed Feb 17, 2024
1 parent d14c3d2 commit 5f15662
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 11 deletions.
28 changes: 17 additions & 11 deletions lilac/data/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,7 @@ def cluster_impl(
if output_path:
cluster_output_path = normalize_path(output_path)
elif path:
# The sibling output path is the same as the input path, but with a different suffix.
index = 0
for i, path_part in enumerate(path):
if path_part == PATH_WILDCARD:
break
else:
index = i

parent = path[:index]
sibling = '_'.join([p for p in path[index:] if p != PATH_WILDCARD])
cluster_output_path = (*parent, f'{sibling}__{FIELD_SUFFIX}')
cluster_output_path = default_cluster_output_path(path)
else:
raise ValueError('input must be provided.')

Expand Down Expand Up @@ -416,3 +406,19 @@ def _hdbscan_cluster(

for cluster_id, membership_prob in zip(labels, memberships):
yield {CLUSTER_ID: int(cluster_id), CLUSTER_MEMBERSHIP_PROB: float(membership_prob)}


def default_cluster_output_path(input_path: Path) -> Path:
  """Compute the default path where cluster results for `input_path` are written.

  The result sits next to the input: the leading components are kept as the parent, and the
  remaining non-wildcard components are joined with '_' and given the cluster field suffix.
  """
  parts = normalize_path(input_path)

  # Walk forward until the first wildcard, remembering the last position seen before it. When the
  # path has no wildcard this ends up pointing at the final component.
  split_at = 0
  for position, part in enumerate(parts):
    if part == PATH_WILDCARD:
      break
    split_at = position

  prefix = parts[:split_at]
  # Wildcards carry no name of their own, so they are dropped from the sibling field name.
  named_parts = [part for part in parts[split_at:] if part != PATH_WILDCARD]
  stem = '_'.join(named_parts)
  return (*prefix, f'{stem}__{FIELD_SUFFIX}')
13 changes: 13 additions & 0 deletions lilac/router_dataset_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pydantic import Field as PydanticField

from .auth import UserInfo, get_session_user, get_user_access
from .data.clustering import default_cluster_output_path
from .dataset_format import DatasetFormatInputSelector, get_dataset_format_cls
from .db_manager import get_dataset
from .router_utils import RouteErrorHandler
Expand Down Expand Up @@ -150,6 +151,18 @@ def run() -> None:
return ClusterResponse(task_id=task_id)


class DefaultClusterOutputPathOptions(BaseModel):
"""Request body for the default cluster output path endpoint."""

# The input path that is passed to `default_cluster_output_path` by the endpoint.
input_path: Path


@router.post('/{namespace}/{dataset_name}/default_cluster_output_path')
def get_default_cluster_output_path(options: DefaultClusterOutputPathOptions) -> Path:
  """Get the default cluster output path for the given input path."""
  # NOTE(review): the route template declares `namespace` and `dataset_name`, but this handler does
  # not take them as parameters; the result is computed from `options.input_path` alone.
  return default_cluster_output_path(options.input_path)


class DeleteSignalOptions(BaseModel):
"""The request for the delete signal endpoint."""

Expand Down
9 changes: 9 additions & 0 deletions web/blueprint/src/lib/components/ComputeClusterModal.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import {
clusterMutation,
queryDatasetManifest,
queryDefaultClusterOutputPath,
queryFormatSelectors
} from '$lib/queries/datasetQueries';
import {queryAuthInfo} from '$lib/queries/serverQueries';
Expand Down Expand Up @@ -54,6 +55,14 @@
formatSelectors.length > 0 &&
selectedFormatSelector != null &&
selectedFormatSelector != 'none';
// When `options.input` is set, query the default cluster output path for it; otherwise hold null.
$: defaultClusterOutputPath = options?.input
? queryDefaultClusterOutputPath({input_path: options.input})
: null;
// Once the query has returned data, use its serialized path as the output column value.
$: {
if ($defaultClusterOutputPath?.data != null) {
outputColumn = serializePath($defaultClusterOutputPath.data);
}
}
$: {
if (options?.output_path != null) {
outputColumn = serializePath(options.output_path);
Expand Down
5 changes: 5 additions & 0 deletions web/blueprint/src/lib/queries/datasetQueries.ts
Original file line number Diff line number Diff line change
Expand Up @@ -337,3 +337,8 @@ export const queryFormatSelectors = createApiQuery(
DatasetsService.getFormatSelectors,
DATASETS_TAG
);

/**
 * Query created with `createApiQuery` around `DatasetsService.getDefaultClusterOutputPath`,
 * registered under `DATASETS_TAG`.
 */
export const queryDefaultClusterOutputPath = createApiQuery(
DatasetsService.getDefaultClusterOutputPath,
DATASETS_TAG
);
1 change: 1 addition & 0 deletions web/lib/fastapi_client/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ export type { DatasetSettings } from './models/DatasetSettings';
export type { DatasetUISettings } from './models/DatasetUISettings';
export type { DatasetUserAccess } from './models/DatasetUserAccess';
export type { DataType } from './models/DataType';
export type { DefaultClusterOutputPathOptions } from './models/DefaultClusterOutputPathOptions';
export type { DeleteRowsOptions } from './models/DeleteRowsOptions';
export type { DeleteSignalOptions } from './models/DeleteSignalOptions';
export type { DeleteSignalResponse } from './models/DeleteSignalResponse';
Expand Down
12 changes: 12 additions & 0 deletions web/lib/fastapi_client/models/DefaultClusterOutputPathOptions.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
/* generated using openapi-typescript-codegen -- do no edit */
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */

/**
* Request body for the default cluster output path endpoint.
*/
export type DefaultClusterOutputPathOptions = {
/** The input path, given either as an array of strings or as a single string. */
input_path: (Array<string> | string);
};

22 changes: 22 additions & 0 deletions web/lib/fastapi_client/services/DatasetsService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import type { ComputeSignalOptions } from '../models/ComputeSignalOptions';
import type { ComputeSignalResponse } from '../models/ComputeSignalResponse';
import type { DatasetInfo } from '../models/DatasetInfo';
import type { DatasetSettings } from '../models/DatasetSettings';
import type { DefaultClusterOutputPathOptions } from '../models/DefaultClusterOutputPathOptions';
import type { DeleteRowsOptions } from '../models/DeleteRowsOptions';
import type { DeleteSignalOptions } from '../models/DeleteSignalOptions';
import type { DeleteSignalResponse } from '../models/DeleteSignalResponse';
Expand Down Expand Up @@ -608,6 +609,27 @@ export class DatasetsService {
});
}

/**
* Get Default Cluster Output Path
* Get the default cluster output path for the input path given in the request body.
*
* NOTE(review): the URL template below contains `{namespace}` and `{dataset_name}`, but no
* path values are passed to the request — confirm how unfilled placeholders are handled.
* @param requestBody
* @returns any Successful Response
* @throws ApiError
*/
public static getDefaultClusterOutputPath(
requestBody: DefaultClusterOutputPathOptions,
): CancelablePromise<(Array<string> | string)> {
return __request(OpenAPI, {
method: 'POST',
url: '/api/v1/datasets/{namespace}/{dataset_name}/default_cluster_output_path',
body: requestBody,
mediaType: 'application/json',
errors: {
422: `Validation Error`,
},
});
}

/**
* Delete Signal
* Delete a signal from a dataset.
Expand Down

0 comments on commit 5f15662

Please sign in to comment.