diff --git a/ui/bun.lockb b/ui/bun.lockb
index 2194dfd..3bd98f1 100755
Binary files a/ui/bun.lockb and b/ui/bun.lockb differ
diff --git a/ui/package.json b/ui/package.json
index 0c8a211..717bf84 100644
--- a/ui/package.json
+++ b/ui/package.json
@@ -12,18 +12,23 @@
"dependencies": {
"@msgpack/msgpack": "^3.0.0-beta2",
"@radix-ui/react-accordion": "^1.1.2",
+ "@radix-ui/react-dialog": "^1.1.1",
"@radix-ui/react-dropdown-menu": "^2.0.6",
"@radix-ui/react-hover-card": "^1.0.7",
"@radix-ui/react-label": "^2.0.2",
+ "@radix-ui/react-popover": "^1.1.1",
"@radix-ui/react-select": "^2.0.0",
"@radix-ui/react-separator": "^1.0.3",
"@radix-ui/react-slot": "^1.0.2",
+ "@radix-ui/react-switch": "^1.1.0",
"@radix-ui/react-tabs": "^1.0.4",
"@radix-ui/react-toggle": "^1.0.3",
+ "@radix-ui/react-tooltip": "^1.1.2",
"@tanstack/react-table": "^8.15.3",
"camelcase-keys": "^9.1.3",
"class-variance-authority": "^0.7.0",
"clsx": "^2.1.0",
+ "cmdk": "1.0.0",
"lucide-react": "^0.358.0",
"plotly.js": "^2.30.1",
"react": "^18.2.0",
@@ -32,6 +37,7 @@
"react-router-dom": "^6.22.3",
"react-use": "^17.5.0",
"recharts": "^2.12.7",
+ "snakecase-keys": "^8.0.1",
"tailwind-merge": "^2.2.1",
"tailwindcss-animate": "^1.0.7",
"zod": "^3.22.4",
diff --git a/ui/src/components/model/model-card.tsx b/ui/src/components/model/model-card.tsx
index ce14e7b..d89ab70 100644
--- a/ui/src/components/model/model-card.tsx
+++ b/ui/src/components/model/model-card.tsx
@@ -1,9 +1,10 @@
-import { Fragment, useEffect, useState } from "react";
+import { Fragment, useEffect, useMemo, useState } from "react";
import { Button } from "../ui/button";
import { Card, CardContent, CardHeader, CardTitle } from "../ui/card";
import { Textarea } from "../ui/textarea";
-import { useAsyncFn } from "react-use";
+import { useAsyncFn, useMount } from "react-use";
import camelcaseKeys from "camelcase-keys";
+import snakecaseKeys from "snakecase-keys";
import { decode } from "@msgpack/msgpack";
import { ModelGeneration, ModelGenerationSchema } from "@/types/model";
import { Sample } from "../app/sample";
@@ -14,53 +15,330 @@ import { Input } from "../ui/input";
import { countTokenGroupPositions, groupToken, hex } from "@/utils/token";
import { getAccentClassname } from "@/utils/style";
import { Separator } from "../ui/separator";
+import MultipleSelector from "../ui/multiple-selector";
+import { z } from "zod";
+import { Combobox } from "../ui/combobox";
+import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "../ui/select";
+import { Ban, MoreHorizontal, Plus, Trash2, Wrench, X } from "lucide-react";
+import { ColumnDef } from "@tanstack/react-table";
+import { FeatureLinkWithPreview } from "../app/feature-preview";
+import { DataTable } from "../ui/data-table";
+import {
+ DropdownMenu,
+ DropdownMenuContent,
+ DropdownMenuItem,
+ DropdownMenuLabel,
+ DropdownMenuSeparator,
+ DropdownMenuTrigger,
+} from "../ui/dropdown-menu";
+import { useNavigate } from "react-router-dom";
+import { Switch } from "../ui/switch";
+import { Label as SLabel } from "../ui/label";
+import { Toggle } from "../ui/toggle";
-const ModelSample = ({ sample }: { sample: ModelGeneration }) => {
+const SAEInfo = ({
+ saeInfo,
+ saeSettings,
+ onSteerFeature,
+ setSAESettings,
+}: {
+ saeInfo: {
+ name: string;
+ featureActs: {
+ featureActIndex: number;
+ featureAct: number;
+ maxFeatureAct: number;
+ }[];
+ };
+ saeSettings: { sortedBySum: boolean };
+ onSteerFeature?: (name: string, featureIndex: number) => void;
+ setSAESettings: (settings: { sortedBySum: boolean }) => void;
+}) => {
+ const navigate = useNavigate();
+
+ const columns: ColumnDef<{
+ featureActIndex: number;
+ featureAct: number;
+ maxFeatureAct: number;
+ }>[] = [
+ {
+ accessorKey: "featureActIndex",
+ header: () => (
+
+ Feature
+
+ ),
+ cell: ({ row }) => (
+
+
+
+ ),
+ },
+ {
+ accessorKey: "featureAct",
+ header: () => (
+
+ Activation
+
+ ),
+ cell: ({ row }) => (
+
+ {row.original.featureAct.toFixed(3)}
+
+ ),
+ },
+ {
+ id: "actions",
+ enableHiding: false,
+ cell: ({ row }) => (
+
+
+
+ Open menu
+
+
+
+
+ Actions
+ onSteerFeature?.(saeInfo.name, row.original.featureActIndex)}>
+ Steer feature
+
+
+ {
+ navigate(
+ `/features?dictionary=${encodeURIComponent(saeInfo.name)}&featureIndex=${
+ row.original.featureActIndex
+ }`
+ );
+ }}
+ >
+ View feature
+
+
+
+ ),
+ meta: {
+ headerClassName: "w-16",
+ cellClassName: "py-0",
+ },
+ },
+ ];
+
+ return (
+
+
+
Features from {saeInfo.name}:
+
setSAESettings({ ...saeSettings, sortedBySum: pressed })}
+ >
+ Sort by sum
+
+
+
+
+ );
+};
+
+const LogitsInfo = ({
+ logits,
+}: {
+ logits: {
+ logits: number;
+ token: Uint8Array;
+ }[];
+}) => {
+ const maxLogits = Math.max(...logits.map((logit) => logit.logits));
+ const columns: ColumnDef<{
+ logits: number;
+ token: Uint8Array;
+ }>[] = [
+ {
+ accessorKey: "token",
+ header: () => (
+
+ Token
+
+ ),
+ cell: ({ row }) => (
+
+ {hex(row.original)}
+
+ ),
+ },
+ {
+ accessorKey: "logits",
+ header: () => (
+
+ Logits
+
+ ),
+ cell: ({ row }) => (
+
+ {row.original.logits.toFixed(3)}
+
+ ),
+ },
+ ];
+
+ return (
+
+ );
+};
+
+const LogitsBarChart = ({
+ tokens,
+}: {
+ tokens: {
+ logits: {
+ logits: number;
+ token: Uint8Array;
+ }[];
+ token: Uint8Array;
+ }[];
+}) => {
+ const data = tokens.map((token) =>
+ Object.assign(
+ {},
+ ...token.logits.map((logit, j) => ({
+ [`logits-${j}`]: logit.logits,
+ [`logits-token-${j}`]: hex(logit),
+ })),
+ {
+ name: hex(token),
+ }
+ )
+ );
+
+ const colors = ["#8884d8", "#82ca9d", "#ffc658", "#ff7300", "#d6d6d6"];
+
+ return (
+
+
+
+
+
+
+
+ {tokens[0].logits.slice(0, 5).map((_, i) => (
+
+
+
+ ))}
+
+
+ );
+};
+
+const ModelSample = ({
+ sample,
+ onSteerFeature,
+}: {
+ sample: ModelGeneration;
+ onSteerFeature?: (name: string, featureIndex: number) => void;
+}) => {
const [selectedTokenGroupIndices, setSelectedTokenGroupIndices] = useState([]);
const toggleSelectedTokenGroupIndex = (tokenGroupIndex: number) => {
setSelectedTokenGroupIndices((prev) =>
- prev.includes(tokenGroupIndex) ? prev.filter((t) => t !== tokenGroupIndex) : [...prev, tokenGroupIndex]
+ prev.includes(tokenGroupIndex)
+ ? prev.filter((t) => t !== tokenGroupIndex)
+ : [...prev, tokenGroupIndex].sort((a, b) => a - b)
);
};
+ const [saeSettings, setSAESettings] = useState<{ [name: string]: { sortedBySum: boolean } }>({});
+ const getSAESettings = (name: string) => saeSettings[name] || { sortedBySum: false };
+
useEffect(() => {
setSelectedTokenGroupIndices([]);
}, [sample]);
- const tokens = zip(sample.context, sample.inputMask, sample.logits, sample.logitsTokens).map(
- ([token, inputMask, logits, logitsTokens]) => ({
- token,
- inputMask,
- logits: zip(logits, logitsTokens).map(([logits, token]) => ({
- logits,
+ const tokens = useMemo(() => {
+ const saeInfo =
+ sample.saeInfo.length > 0
+ ? zip(
+ ...sample.saeInfo.map((sae) =>
+ zip(sae.featureActsIndices, sae.featureActs, sae.maxFeatureActs).map(
+ ([featureActIndex, featureAct, maxFeatureAct]) => ({
+ name: sae.name,
+ featureActs: zip(featureActIndex, featureAct, maxFeatureAct).map(
+ ([featureActIndex, featureAct, maxFeatureAct]) => ({
+ featureActIndex,
+ featureAct,
+ maxFeatureAct,
+ })
+ ),
+ })
+ )
+ )
+ )
+ : sample.context.map(() => []);
+
+ return zip(sample.context, sample.inputMask, sample.logits, sample.logitsTokens, saeInfo).map(
+ ([token, inputMask, logits, logitsTokens, saeInfo]) => ({
token,
+ inputMask,
+ logits: zip(logits, logitsTokens).map(([logits, token]) => ({
+ logits,
+ token,
+ })),
+ saeInfo,
+ })
+ );
+ }, [sample]);
+
+ type Token = (typeof tokens)[0];
+
+ const sortTokenInfo = (tokens: Token[]) => {
+ const featureActSum = tokens.reduce((acc, token) => {
+ token.saeInfo.forEach((saeInfo) => {
+ saeInfo.featureActs.forEach((featureAct) => {
+ acc[saeInfo.name] = acc[saeInfo.name] || {};
+ acc[saeInfo.name][featureAct.featureActIndex.toString()] =
+ acc[saeInfo.name][featureAct.featureActIndex.toString()] || 0;
+ acc[saeInfo.name][featureAct.featureActIndex.toString()] += featureAct.featureAct;
+ });
+ });
+ return acc;
+ }, {} as { [name: string]: { [featureIndex: string]: number } });
+
+ return tokens.map((token) => ({
+ ...token,
+ logits: token.logits.sort((a, b) => b.logits - a.logits),
+ saeInfo: token.saeInfo.map((saeInfo) => ({
+ ...saeInfo,
+ featureActs: getSAESettings(saeInfo.name).sortedBySum
+ ? saeInfo.featureActs.sort(
+ (a, b) =>
+ featureActSum[saeInfo.name][b.featureActIndex.toString()] -
+ featureActSum[saeInfo.name][a.featureActIndex.toString()]
+ )
+ : saeInfo.featureActs.sort((a, b) => b.featureAct - a.featureAct),
})),
- })
- );
+ }));
+ };
+
const tokenGroups = groupToken(tokens);
const tokenGroupPositions = countTokenGroupPositions(tokenGroups);
const selectedTokenGroups = selectedTokenGroupIndices.map((i) => tokenGroups[i]);
- const selectedTokens = selectedTokenGroups.flatMap((t) => t);
+ const selectedTokens = sortTokenInfo(selectedTokenGroups.flatMap((t) => t));
const selectedTokenGroupPositions = selectedTokenGroupIndices.map((i) => tokenGroupPositions[i]);
const selectedTokenPositions = selectedTokenGroups.flatMap((t, i) =>
t.map((_, j) => selectedTokenGroupPositions[i] + j)
);
- const data = selectedTokens.map((token) =>
- Object.assign(
- {},
- ...token.logits.map((logits, j) => ({
- [`logits-${j}`]: logits.logits,
- [`logits-token-${j}`]: hex(logits),
- })),
- {
- name: hex(token),
- }
- )
- );
-
- const colors = ["#8884d8", "#82ca9d", "#ffc658", "#ff7300", "#d6d6d6"];
-
return (
{
})}
/>
- {selectedTokens.length > 0 && Detail of Selected Tokens:
}
+ {selectedTokens.length > 0 && (
+
+ Detail of {selectedTokens.length} Selected Token{selectedTokens.length > 1 ? "s" : ""}:
+
+ )}
{selectedTokens.map((token, i) => (
-
-
-
-
Token:
-
{hex(token)}
-
Position:
-
{selectedTokenPositions[i]}
-
-
-
-
Top Logits:
-
- {token.logits.map((logit, j) => (
-
- {hex(logit)}
-
- {logit.logits.toFixed(3)}
-
-
- ))}
-
+
+
+
Token:
+
{hex(token)}
+
Position:
+
{selectedTokenPositions[i]}
+
+
+ {token.saeInfo.map((saeInfo, j) => (
+ setSAESettings({ ...saeSettings, [saeInfo.name]: settings })}
+ />
+ ))}
+
{i < selectedTokens.length - 1 &&
}
))}
- {selectedTokens.length > 0 && (
-
-
-
-
-
-
-
- {selectedTokens[0].logits.map((_, i) => (
-
-
-
- ))}
-
-
- )}
+ {selectedTokens.length > 0 &&
}
);
};
const ModelCustomInputArea = () => {
const [customInput, setCustomInput] = useState
("");
+ const [doGenerate, setDoGenerate] = useState(true);
const [maxNewTokens, setMaxNewTokens] = useState(128);
const [topK, setTopK] = useState(50);
const [topP, setTopP] = useState(0.95);
- const [logitTopK, setLogitTopK] = useState(5);
+ const [selectedDictionaries, setSelectedDictionaries] = useState([]);
+ const [steerings, setSteerings] = useState<
+ {
+ sae: string | null;
+ featureIndex: number;
+ steeringType: "times" | "ablate" | "add" | "set";
+ steeringValue: number | null;
+ }[]
+ >([{ sae: null, featureIndex: 0, steeringType: "times", steeringValue: 1 }]);
+
const [sample, setSample] = useState(null);
+
+ const [dictionariesState, fetchDictionaries] = useAsyncFn(async () => {
+ return await fetch(`${import.meta.env.VITE_BACKEND_URL}/dictionaries`)
+ .then(async (res) => await res.json())
+ .then((res) => z.array(z.string()).parse(res));
+ });
+
+ useMount(async () => {
+ await fetchDictionaries();
+ });
+
const [state, submit] = useAsyncFn(async () => {
if (!customInput) {
alert("Please enter your input.");
return;
}
- const sample = await fetch(
- `${import.meta.env.VITE_BACKEND_URL}/model/generate?input_text=${encodeURIComponent(
- customInput
- )}&max_new_tokens=${encodeURIComponent(maxNewTokens.toString())}&top_k=${encodeURIComponent(
- topK.toString()
- )}&top_p=${encodeURIComponent(topP.toString())}&return_logits_top_k=${encodeURIComponent(logitTopK.toString())}`,
- {
- method: "POST",
- headers: {
- Accept: "application/x-msgpack",
- },
- }
- )
+ const sample = await fetch(`${import.meta.env.VITE_BACKEND_URL}/model/generate`, {
+ method: "POST",
+ body: JSON.stringify(
+ snakecaseKeys({
+ inputText: customInput,
+ maxNewTokens: doGenerate ? maxNewTokens : 0,
+ topK,
+ topP,
+ saes: selectedDictionaries,
+ steerings: steerings.filter((s) => s.sae !== null),
+ })
+ ),
+ headers: {
+ Accept: "application/x-msgpack",
+ "Content-Type": "application/json",
+ },
+ })
.then(async (res) => {
if (!res.ok) {
throw new Error(await res.text());
@@ -170,15 +459,24 @@ const ModelCustomInputArea = () => {
)
.then((res) => ModelGenerationSchema.parse(res));
setSample(sample);
- }, [customInput]);
+ }, [doGenerate, customInput, maxNewTokens, topK, topP, selectedDictionaries]);
return (
Generation
+
Do generate:
+
+
+
+ {doGenerate
+ ? "Model will generate new contents based on the following configuration."
+ : "Model will only perceive the given input."}
+
+
Max new tokens:
{
/>
Top K:
{
/>
Top P:
setTopP(parseFloat(e.target.value))}
/>
-
Logit Top K:
-
+
SAEs:
+
setLogitTopK(parseInt(e.target.value))}
+ disabled={state.loading}
+ options={dictionariesState.value?.map((name) => ({ value: name, label: name })) || []}
+ commandProps={{
+ className: "col-span-3",
+ }}
+ hidePlaceholderWhenSelected
+ placeholder="Bind SAEs to the language model to see features activated in the generation."
+ value={selectedDictionaries.map((name) => ({ value: name, label: name }))}
+ onChange={(value) => setSelectedDictionaries(value.map((v) => v.value))}
+ emptyIndicator={
+ No dictionaries found.
+ }
/>
+ {selectedDictionaries.length > 0 &&
+ steerings.map((steering, i) => (
+
+ Steering {i + 1}:
+
+
+
{
+ setSteerings((prev) => prev.map((s, j) => (i === j ? { ...s, sae: value } : s)));
+ }}
+ options={selectedDictionaries.map((name) => ({ value: name, label: name }))}
+ disabled={state.loading}
+ placeholder="Select a dictionary to steer the generation."
+ commandPlaceholder="Search for a selected dictionary..."
+ emptyIndicator="No options found"
+ />
+ #
+ {
+ setSteerings((prev) =>
+ prev.map((s, j) => (i === j ? { ...s, featureIndex: parseInt(e.target.value) } : s))
+ );
+ }}
+ />
+ {
+ if (value === "times" || value === "add" || value === "set")
+ setSteerings((prev) =>
+ prev.map((s, j) => (i === j ? { ...s, steeringType: value, steeringValue: 1 } : s))
+ );
+ else if (value === "ablate")
+ setSteerings((prev) =>
+ prev.map((s, j) => (i === j ? { ...s, steeringType: value, steeringValue: null } : s))
+ );
+ }}
+ >
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {
+ setSteerings((prev) =>
+ prev.map((s, j) => (i === j ? { ...s, steeringValue: parseFloat(e.target.value) } : s))
+ );
+ }}
+ />
+
+
+ steerings.length > 1
+ ? setSteerings((prev) => prev.filter((_, j) => i !== j))
+ : setSteerings([{ sae: null, featureIndex: 0, steeringType: "times", steeringValue: 1 }])
+ }
+ >
+
+
+
+ setSteerings((prev) => [
+ ...prev.slice(0, i + 1),
+ { sae: null, featureIndex: 0, steeringType: "times", steeringValue: 1 },
+ ...prev.slice(i + 1),
+ ])
+ }
+ >
+
+
+
+
+ ))}
);
};
diff --git a/ui/src/components/ui/badge.tsx b/ui/src/components/ui/badge.tsx
new file mode 100644
index 0000000..f000e3e
--- /dev/null
+++ b/ui/src/components/ui/badge.tsx
@@ -0,0 +1,36 @@
+import * as React from "react"
+import { cva, type VariantProps } from "class-variance-authority"
+
+import { cn } from "@/lib/utils"
+
+const badgeVariants = cva(
+ "inline-flex items-center rounded-full border px-2.5 py-0.5 text-xs font-semibold transition-colors focus:outline-none focus:ring-2 focus:ring-ring focus:ring-offset-2",
+ {
+ variants: {
+ variant: {
+ default:
+ "border-transparent bg-primary text-primary-foreground hover:bg-primary/80",
+ secondary:
+ "border-transparent bg-secondary text-secondary-foreground hover:bg-secondary/80",
+ destructive:
+ "border-transparent bg-destructive text-destructive-foreground hover:bg-destructive/80",
+ outline: "text-foreground",
+ },
+ },
+ defaultVariants: {
+ variant: "default",
+ },
+ }
+)
+
+export interface BadgeProps
+ extends React.HTMLAttributes,
+ VariantProps {}
+
+function Badge({ className, variant, ...props }: BadgeProps) {
+ return (
+
+ )
+}
+
+export { Badge, badgeVariants }
diff --git a/ui/src/components/ui/combobox.tsx b/ui/src/components/ui/combobox.tsx
new file mode 100644
index 0000000..ecf00d3
--- /dev/null
+++ b/ui/src/components/ui/combobox.tsx
@@ -0,0 +1,92 @@
+"use client";
+
+import * as React from "react";
+import { Check, ChevronsUpDown } from "lucide-react";
+
+import { cn } from "@/lib/utils";
+import { Button } from "@/components/ui/button";
+import { Command, CommandEmpty, CommandGroup, CommandInput, CommandItem, CommandList } from "@/components/ui/command";
+import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
+
+export type ComboboxProps = {
+ value?: string | null;
+ onChange?: (value: string) => void;
+ options: { value: string; label: string }[];
+ placeholder?: string;
+ commandPlaceholder?: string;
+ emptyIndicator?: string;
+ className?: string;
+ disabled?: boolean;
+};
+
+export function Combobox({
+ value,
+ onChange,
+ options,
+ placeholder,
+ commandPlaceholder,
+ emptyIndicator,
+ className,
+ disabled,
+}: ComboboxProps) {
+ const [open, setOpen] = React.useState(false);
+ const [internalValue, setInternalValue] = React.useState((value ?? options[0]?.value) || null);
+
+ React.useEffect(() => {
+ if (value !== undefined) {
+ setInternalValue(value);
+ }
+ }, [value]);
+
+ const setValue = React.useCallback(
+ (value: string) => {
+ setInternalValue(value);
+ onChange?.(value);
+ },
+ [onChange]
+ );
+
+ return (
+
+
+
+ {internalValue ? options.find((option) => option.value === internalValue)?.label : placeholder || "Select..."}
+
+
+
+
+
+
+
+ {emptyIndicator || "No options found"}
+
+ {options.map((option) => (
+ {
+ setValue(currentValue === internalValue ? "" : currentValue);
+ setOpen(false);
+ }}
+ >
+
+ {option.label}
+
+ ))}
+
+
+
+
+
+ );
+}
diff --git a/ui/src/components/ui/command.tsx b/ui/src/components/ui/command.tsx
new file mode 100644
index 0000000..56a0979
--- /dev/null
+++ b/ui/src/components/ui/command.tsx
@@ -0,0 +1,153 @@
+import * as React from "react"
+import { type DialogProps } from "@radix-ui/react-dialog"
+import { Command as CommandPrimitive } from "cmdk"
+import { Search } from "lucide-react"
+
+import { cn } from "@/lib/utils"
+import { Dialog, DialogContent } from "@/components/ui/dialog"
+
+const Command = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ className, ...props }, ref) => (
+
+))
+Command.displayName = CommandPrimitive.displayName
+
+interface CommandDialogProps extends DialogProps {}
+
+const CommandDialog = ({ children, ...props }: CommandDialogProps) => {
+ return (
+
+
+
+ {children}
+
+
+
+ )
+}
+
+const CommandInput = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ className, ...props }, ref) => (
+
+
+
+
+))
+
+CommandInput.displayName = CommandPrimitive.Input.displayName
+
+const CommandList = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ className, ...props }, ref) => (
+
+))
+
+CommandList.displayName = CommandPrimitive.List.displayName
+
+const CommandEmpty = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>((props, ref) => (
+
+))
+
+CommandEmpty.displayName = CommandPrimitive.Empty.displayName
+
+const CommandGroup = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ className, ...props }, ref) => (
+
+))
+
+CommandGroup.displayName = CommandPrimitive.Group.displayName
+
+const CommandSeparator = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ className, ...props }, ref) => (
+
+))
+CommandSeparator.displayName = CommandPrimitive.Separator.displayName
+
+const CommandItem = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ className, ...props }, ref) => (
+
+))
+
+CommandItem.displayName = CommandPrimitive.Item.displayName
+
+const CommandShortcut = ({
+ className,
+ ...props
+}: React.HTMLAttributes) => {
+ return (
+
+ )
+}
+CommandShortcut.displayName = "CommandShortcut"
+
+export {
+ Command,
+ CommandDialog,
+ CommandInput,
+ CommandList,
+ CommandEmpty,
+ CommandGroup,
+ CommandItem,
+ CommandShortcut,
+ CommandSeparator,
+}
diff --git a/ui/src/components/ui/data-table.tsx b/ui/src/components/ui/data-table.tsx
index 2e32c7d..b0a6a81 100644
--- a/ui/src/components/ui/data-table.tsx
+++ b/ui/src/components/ui/data-table.tsx
@@ -7,9 +7,10 @@ import { useState } from "react";
interface DataTableProps {
columns: ColumnDef[];
data: TData[];
+ pageSize?: number;
}
-export function DataTable({ columns, data }: DataTableProps) {
+export function DataTable({ columns, data, pageSize = 10 }: DataTableProps) {
const [page, setPage] = useState(1);
const table = useReactTable({
@@ -20,7 +21,7 @@ export function DataTable({ columns, data }: DataTableProps({ columns, data }: DataTableProps
{headerGroup.headers.map((header) => {
return (
-
+
{header.isPlaceholder ? null : flexRender(header.column.columnDef.header, header.getContext())}
);
@@ -49,7 +50,9 @@ export function DataTable({ columns, data }: DataTableProps (
{row.getVisibleCells().map((cell) => (
- {flexRender(cell.column.columnDef.cell, cell.getContext())}
+
+ {flexRender(cell.column.columnDef.cell, cell.getContext())}
+
))}
))
diff --git a/ui/src/components/ui/dialog.tsx b/ui/src/components/ui/dialog.tsx
new file mode 100644
index 0000000..c23630e
--- /dev/null
+++ b/ui/src/components/ui/dialog.tsx
@@ -0,0 +1,120 @@
+import * as React from "react"
+import * as DialogPrimitive from "@radix-ui/react-dialog"
+import { X } from "lucide-react"
+
+import { cn } from "@/lib/utils"
+
+const Dialog = DialogPrimitive.Root
+
+const DialogTrigger = DialogPrimitive.Trigger
+
+const DialogPortal = DialogPrimitive.Portal
+
+const DialogClose = DialogPrimitive.Close
+
+const DialogOverlay = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ className, ...props }, ref) => (
+
+))
+DialogOverlay.displayName = DialogPrimitive.Overlay.displayName
+
+const DialogContent = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ className, children, ...props }, ref) => (
+
+
+
+ {children}
+
+
+ Close
+
+
+
+))
+DialogContent.displayName = DialogPrimitive.Content.displayName
+
+const DialogHeader = ({
+ className,
+ ...props
+}: React.HTMLAttributes) => (
+
+)
+DialogHeader.displayName = "DialogHeader"
+
+const DialogFooter = ({
+ className,
+ ...props
+}: React.HTMLAttributes) => (
+
+)
+DialogFooter.displayName = "DialogFooter"
+
+const DialogTitle = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ className, ...props }, ref) => (
+
+))
+DialogTitle.displayName = DialogPrimitive.Title.displayName
+
+const DialogDescription = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ className, ...props }, ref) => (
+
+))
+DialogDescription.displayName = DialogPrimitive.Description.displayName
+
+export {
+ Dialog,
+ DialogPortal,
+ DialogOverlay,
+ DialogClose,
+ DialogTrigger,
+ DialogContent,
+ DialogHeader,
+ DialogFooter,
+ DialogTitle,
+ DialogDescription,
+}
diff --git a/ui/src/components/ui/multiple-selector.tsx b/ui/src/components/ui/multiple-selector.tsx
new file mode 100644
index 0000000..8b1d5d7
--- /dev/null
+++ b/ui/src/components/ui/multiple-selector.tsx
@@ -0,0 +1,558 @@
+import { Command as CommandPrimitive, useCommandState } from "cmdk";
+import { X } from "lucide-react";
+import * as React from "react";
+import { forwardRef, useEffect } from "react";
+
+import { Badge } from "@/components/ui/badge";
+import { Command, CommandGroup, CommandItem, CommandList } from "@/components/ui/command";
+import { cn } from "@/lib/utils";
+
+export interface Option {
+ value: string;
+ label: string;
+ disable?: boolean;
+ /** fixed option that can't be removed. */
+ fixed?: boolean;
+ /** Group the options by providing key. */
+ [key: string]: string | boolean | undefined;
+}
+interface GroupOption {
+ [key: string]: Option[];
+}
+
+interface MultipleSelectorProps {
+ value?: Option[];
+ defaultOptions?: Option[];
+ /** manually controlled options */
+ options?: Option[];
+ placeholder?: string;
+ /** Loading component. */
+ loadingIndicator?: React.ReactNode;
+ /** Empty component. */
+ emptyIndicator?: React.ReactNode;
+ /** Debounce time for async search. Only work with `onSearch`. */
+ delay?: number;
+ /**
+ * Only work with `onSearch` prop. Trigger search when `onFocus`.
+ * For example, when user click on the input, it will trigger the search to get initial options.
+ **/
+ triggerSearchOnFocus?: boolean;
+ /** async search */
+ onSearch?: (value: string) => Promise;
+ onChange?: (options: Option[]) => void;
+ /** Limit the maximum number of selected options. */
+ maxSelected?: number;
+ /** When the number of selected options exceeds the limit, the onMaxSelected will be called. */
+ onMaxSelected?: (maxLimit: number) => void;
+ /** Hide the placeholder when there are options selected. */
+ hidePlaceholderWhenSelected?: boolean;
+ disabled?: boolean;
+ /** Group the options base on provided key. */
+ groupBy?: string;
+ className?: string;
+ badgeClassName?: string;
+ /**
+ * First item selected is a default behavior by cmdk. That is why the default is true.
+ * This is a workaround solution by add a dummy item.
+ *
+ * @reference: https://github.com/pacocoursey/cmdk/issues/171
+ */
+ selectFirstItem?: boolean;
+ /** Allow user to create option when there is no option matched. */
+ creatable?: boolean;
+ /** Props of `Command` */
+ commandProps?: React.ComponentPropsWithoutRef;
+ /** Props of `CommandInput` */
+ inputProps?: Omit<
+ React.ComponentPropsWithoutRef,
+ "value" | "placeholder" | "disabled"
+ >;
+ /** hide the clear all button. */
+ hideClearAllButton?: boolean;
+}
+
+export interface MultipleSelectorRef {
+ selectedValue: Option[];
+ input: HTMLInputElement;
+}
+
+export function useDebounce(value: T, delay?: number): T {
+ const [debouncedValue, setDebouncedValue] = React.useState(value);
+
+ useEffect(() => {
+ const timer = setTimeout(() => setDebouncedValue(value), delay || 500);
+
+ return () => {
+ clearTimeout(timer);
+ };
+ }, [value, delay]);
+
+ return debouncedValue;
+}
+
+function transToGroupOption(options: Option[], groupBy?: string) {
+ if (options.length === 0) {
+ return {};
+ }
+ if (!groupBy) {
+ return {
+ "": options,
+ };
+ }
+
+ const groupOption: GroupOption = {};
+ options.forEach((option) => {
+ const key = (option[groupBy] as string) || "";
+ if (!groupOption[key]) {
+ groupOption[key] = [];
+ }
+ groupOption[key].push(option);
+ });
+ return groupOption;
+}
+
+function removePickedOption(groupOption: GroupOption, picked: Option[]) {
+ const cloneOption = JSON.parse(JSON.stringify(groupOption)) as GroupOption;
+
+ for (const [key, value] of Object.entries(cloneOption)) {
+ cloneOption[key] = value.filter((val) => !picked.find((p) => p.value === val.value));
+ }
+ return cloneOption;
+}
+
+function isOptionsExist(groupOption: GroupOption, targetOption: Option[]) {
+ for (const [, value] of Object.entries(groupOption)) {
+ if (value.some((option) => targetOption.find((p) => p.value === option.value))) {
+ return true;
+ }
+ }
+ return false;
+}
+
+/**
+ * The `CommandEmpty` of shadcn/ui will cause the cmdk empty not rendering correctly.
+ * So we create one and copy the `Empty` implementation from `cmdk`.
+ *
+ * @reference: https://github.com/hsuanyi-chou/shadcn-ui-expansions/issues/34#issuecomment-1949561607
+ **/
+const CommandEmpty = forwardRef>(
+ ({ className, ...props }, forwardedRef) => {
+ const render = useCommandState((state) => state.filtered.count === 0);
+
+ if (!render) return null;
+
+ return (
+
+ );
+ }
+);
+
+CommandEmpty.displayName = "CommandEmpty";
+
+const MultipleSelector = React.forwardRef(
+ (
+ {
+ value,
+ onChange,
+ placeholder,
+ defaultOptions: arrayDefaultOptions = [],
+ options: arrayOptions,
+ delay,
+ onSearch,
+ loadingIndicator,
+ emptyIndicator,
+ maxSelected = Number.MAX_SAFE_INTEGER,
+ onMaxSelected,
+ hidePlaceholderWhenSelected,
+ disabled,
+ groupBy,
+ className,
+ badgeClassName,
+ selectFirstItem = true,
+ creatable = false,
+ triggerSearchOnFocus = false,
+ commandProps,
+ inputProps,
+ hideClearAllButton = false,
+ }: MultipleSelectorProps,
+ ref: React.Ref
+ ) => {
+ const inputRef = React.useRef(null);
+ const [open, setOpen] = React.useState(false);
+ const [onScrollbar, setOnScrollbar] = React.useState(false);
+ const [isLoading, setIsLoading] = React.useState(false);
+ const dropdownRef = React.useRef(null); // Added this
+
+ const [selected, setSelected] = React.useState(value || []);
+ const [options, setOptions] = React.useState(transToGroupOption(arrayDefaultOptions, groupBy));
+ const [inputValue, setInputValue] = React.useState("");
+ const debouncedSearchTerm = useDebounce(inputValue, delay || 500);
+
+ React.useImperativeHandle(
+ ref,
+ () => ({
+ selectedValue: [...selected],
+ input: inputRef.current as HTMLInputElement,
+ focus: () => inputRef.current?.focus(),
+ }),
+ [selected]
+ );
+
+ const handleClickOutside = (event: MouseEvent | TouchEvent) => {
+ if (
+ dropdownRef.current &&
+ !dropdownRef.current.contains(event.target as Node) &&
+ inputRef.current &&
+ !inputRef.current.contains(event.target as Node)
+ ) {
+ setOpen(false);
+ }
+ };
+
+ const handleUnselect = React.useCallback(
+ (option: Option) => {
+ const newOptions = selected.filter((s) => s.value !== option.value);
+ setSelected(newOptions);
+ onChange?.(newOptions);
+ },
+ [onChange, selected]
+ );
+
+ const handleKeyDown = React.useCallback(
+ (e: React.KeyboardEvent) => {
+ const input = inputRef.current;
+ if (input) {
+ if (e.key === "Delete" || e.key === "Backspace") {
+ if (input.value === "" && selected.length > 0) {
+ const lastSelectOption = selected[selected.length - 1];
+ // If last item is fixed, we should not remove it.
+ if (!lastSelectOption.fixed) {
+ handleUnselect(selected[selected.length - 1]);
+ }
+ }
+ }
+ // This is not a default behavior of the field
+ if (e.key === "Escape") {
+ input.blur();
+ }
+ }
+ },
+ [handleUnselect, selected]
+ );
+
+ useEffect(() => {
+ if (open) {
+ document.addEventListener("mousedown", handleClickOutside);
+ document.addEventListener("touchend", handleClickOutside);
+ } else {
+ document.removeEventListener("mousedown", handleClickOutside);
+ document.removeEventListener("touchend", handleClickOutside);
+ }
+
+ return () => {
+ document.removeEventListener("mousedown", handleClickOutside);
+ document.removeEventListener("touchend", handleClickOutside);
+ };
+ }, [open]);
+
+ useEffect(() => {
+ if (value) {
+ setSelected(value);
+ }
+ }, [value]);
+
+ useEffect(() => {
+ /** If `onSearch` is provided, do not trigger options updated. */
+ if (!arrayOptions || onSearch) {
+ return;
+ }
+ const newOption = transToGroupOption(arrayOptions || [], groupBy);
+ if (JSON.stringify(newOption) !== JSON.stringify(options)) {
+ setOptions(newOption);
+ }
+ }, [arrayDefaultOptions, arrayOptions, groupBy, onSearch, options]);
+
+ useEffect(() => {
+ const doSearch = async () => {
+ setIsLoading(true);
+ const res = await onSearch?.(debouncedSearchTerm);
+ setOptions(transToGroupOption(res || [], groupBy));
+ setIsLoading(false);
+ };
+
+ const exec = async () => {
+ if (!onSearch || !open) return;
+
+ if (triggerSearchOnFocus) {
+ await doSearch();
+ }
+
+ if (debouncedSearchTerm) {
+ await doSearch();
+ }
+ };
+
+ void exec();
+ // eslint-disable-next-line react-hooks/exhaustive-deps
+ }, [debouncedSearchTerm, groupBy, open, triggerSearchOnFocus]);
+
+ const CreatableItem = () => {
+ if (!creatable) return undefined;
+ if (
+ isOptionsExist(options, [{ value: inputValue, label: inputValue }]) ||
+ selected.find((s) => s.value === inputValue)
+ ) {
+ return undefined;
+ }
+
+ const Item = (
+ {
+ e.preventDefault();
+ e.stopPropagation();
+ }}
+ onSelect={(value: string) => {
+ if (selected.length >= maxSelected) {
+ onMaxSelected?.(selected.length);
+ return;
+ }
+ setInputValue("");
+ const newOptions = [...selected, { value, label: value }];
+ setSelected(newOptions);
+ onChange?.(newOptions);
+ }}
+ >
+ {`Create "${inputValue}"`}
+
+ );
+
+ // For normal creatable
+ if (!onSearch && inputValue.length > 0) {
+ return Item;
+ }
+
+ // For async search creatable. avoid showing creatable item before loading at first.
+ if (onSearch && debouncedSearchTerm.length > 0 && !isLoading) {
+ return Item;
+ }
+
+ return undefined;
+ };
+
+ const EmptyItem = React.useCallback(() => {
+ if (!emptyIndicator) return undefined;
+
+ // For async search that showing emptyIndicator
+ if (onSearch && !creatable && Object.keys(options).length === 0) {
+ return (
+
+ {emptyIndicator}
+
+ );
+ }
+
+ return {emptyIndicator} ;
+ }, [creatable, emptyIndicator, onSearch, options]);
+
+ const selectables = React.useMemo(() => removePickedOption(options, selected), [options, selected]);
+
+ /** Avoid Creatable Selector freezing or lagging when paste a long string. */
+ const commandFilter = React.useCallback(() => {
+ if (commandProps?.filter) {
+ return commandProps.filter;
+ }
+
+ if (creatable) {
+ return (value: string, search: string) => {
+ return value.toLowerCase().includes(search.toLowerCase()) ? 1 : -1;
+ };
+ }
+ // Using default filter in `cmdk`. We don't have to provide it.
+ return undefined;
+ }, [creatable, commandProps?.filter]);
+
+ return (
+ {
+ handleKeyDown(e);
+ commandProps?.onKeyDown?.(e);
+ }}
+ className={cn("h-auto overflow-visible bg-transparent", commandProps?.className)}
+ shouldFilter={commandProps?.shouldFilter !== undefined ? commandProps.shouldFilter : !onSearch} // When onSearch is provided, we don't want to filter the options. You can still override it.
+ filter={commandFilter()}
+ >
+ {
+ if (disabled) return;
+ inputRef.current?.focus();
+ }}
+ >
+
+ {selected.map((option) => {
+ return (
+
+ {option.label}
+ {
+ if (e.key === "Enter") {
+ handleUnselect(option);
+ }
+ }}
+ onMouseDown={(e) => {
+ e.preventDefault();
+ e.stopPropagation();
+ }}
+ onClick={() => handleUnselect(option)}
+ >
+
+
+
+ );
+ })}
+ {/* Avoid having the "Search" Icon */}
+ {
+ setInputValue(value);
+ inputProps?.onValueChange?.(value);
+ }}
+ onBlur={(event) => {
+ if (!onScrollbar) {
+ setOpen(false);
+ }
+ inputProps?.onBlur?.(event);
+ }}
+ onFocus={(event) => {
+ setOpen(true);
+ triggerSearchOnFocus && onSearch?.(debouncedSearchTerm);
+ inputProps?.onFocus?.(event);
+ }}
+ placeholder={hidePlaceholderWhenSelected && selected.length !== 0 ? "" : placeholder}
+ className={cn(
+ "flex-1 bg-transparent outline-none placeholder:text-muted-foreground",
+ {
+ "w-full": hidePlaceholderWhenSelected,
+ "px-3 py-2": selected.length === 0,
+ "ml-1": selected.length !== 0,
+ },
+ inputProps?.className
+ )}
+ />
+ {
+ setSelected(selected.filter((s) => s.fixed));
+ onChange?.(selected.filter((s) => s.fixed));
+ }}
+ className={cn(
+ "absolute right-0 h-6 w-6 p-0",
+ (hideClearAllButton ||
+ disabled ||
+ selected.length < 1 ||
+ selected.filter((s) => s.fixed).length === selected.length) &&
+ "hidden"
+ )}
+ >
+
+
+
+
+
+ {open && (
+ {
+ setOnScrollbar(false);
+ }}
+ onMouseEnter={() => {
+ setOnScrollbar(true);
+ }}
+ onMouseUp={() => {
+ inputRef.current?.focus();
+ }}
+ >
+ {isLoading ? (
+ <>{loadingIndicator}>
+ ) : (
+ <>
+ {EmptyItem()}
+ {CreatableItem()}
+ {!selectFirstItem && }
+ {Object.entries(selectables).map(([key, dropdowns]) => (
+
+ <>
+ {dropdowns.map((option) => {
+ return (
+ {
+ e.preventDefault();
+ e.stopPropagation();
+ }}
+ onSelect={() => {
+ if (selected.length >= maxSelected) {
+ onMaxSelected?.(selected.length);
+ return;
+ }
+ setInputValue("");
+ const newOptions = [...selected, option];
+ setSelected(newOptions);
+ onChange?.(newOptions);
+ }}
+ className={cn("cursor-pointer", option.disable && "cursor-default text-muted-foreground")}
+ >
+ {option.label}
+
+ );
+ })}
+ >
+
+ ))}
+ >
+ )}
+
+ )}
+
+
+ );
+ }
+);
+
+MultipleSelector.displayName = "MultipleSelector";
+export default MultipleSelector;
diff --git a/ui/src/components/ui/pagination.tsx b/ui/src/components/ui/pagination.tsx
index 8aa84b6..b189c90 100644
--- a/ui/src/components/ui/pagination.tsx
+++ b/ui/src/components/ui/pagination.tsx
@@ -86,13 +86,15 @@ export const AppPagination = ({
page,
setPage,
maxPage,
+ className,
}: {
page: number;
setPage: React.Dispatch>;
maxPage: number;
+ className?: string;
}) => {
return (
-
+
setPage((prev) => Math.max(1, prev - 1))} />
diff --git a/ui/src/components/ui/popover.tsx b/ui/src/components/ui/popover.tsx
new file mode 100644
index 0000000..bbba7e0
--- /dev/null
+++ b/ui/src/components/ui/popover.tsx
@@ -0,0 +1,29 @@
+import * as React from "react"
+import * as PopoverPrimitive from "@radix-ui/react-popover"
+
+import { cn } from "@/lib/utils"
+
+const Popover = PopoverPrimitive.Root
+
+const PopoverTrigger = PopoverPrimitive.Trigger
+
+const PopoverContent = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ className, align = "center", sideOffset = 4, ...props }, ref) => (
+
+
+
+))
+PopoverContent.displayName = PopoverPrimitive.Content.displayName
+
+export { Popover, PopoverTrigger, PopoverContent }
diff --git a/ui/src/components/ui/switch.tsx b/ui/src/components/ui/switch.tsx
new file mode 100644
index 0000000..aa58baa
--- /dev/null
+++ b/ui/src/components/ui/switch.tsx
@@ -0,0 +1,27 @@
+import * as React from "react"
+import * as SwitchPrimitives from "@radix-ui/react-switch"
+
+import { cn } from "@/lib/utils"
+
+const Switch = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ className, ...props }, ref) => (
+
+
+
+))
+Switch.displayName = SwitchPrimitives.Root.displayName
+
+export { Switch }
diff --git a/ui/src/components/ui/tooltip.tsx b/ui/src/components/ui/tooltip.tsx
new file mode 100644
index 0000000..e121f0a
--- /dev/null
+++ b/ui/src/components/ui/tooltip.tsx
@@ -0,0 +1,28 @@
+import * as React from "react"
+import * as TooltipPrimitive from "@radix-ui/react-tooltip"
+
+import { cn } from "@/lib/utils"
+
+const TooltipProvider = TooltipPrimitive.Provider
+
+const Tooltip = TooltipPrimitive.Root
+
+const TooltipTrigger = TooltipPrimitive.Trigger
+
+const TooltipContent = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ className, sideOffset = 4, ...props }, ref) => (
+
+))
+TooltipContent.displayName = TooltipPrimitive.Content.displayName
+
+export { Tooltip, TooltipTrigger, TooltipContent, TooltipProvider }
diff --git a/ui/src/tanstack.d.ts b/ui/src/tanstack.d.ts
new file mode 100644
index 0000000..5c3673f
--- /dev/null
+++ b/ui/src/tanstack.d.ts
@@ -0,0 +1,9 @@
+import "@tanstack/react-table";
+
+declare module "@tanstack/react-table" {
+ // eslint-disable-next-line @typescript-eslint/no-unused-vars
+ interface ColumnMeta {
+ cellClassName?: string;
+ headerClassName?: string;
+ }
+}
diff --git a/ui/src/types/model.ts b/ui/src/types/model.ts
index d01af56..93710d5 100644
--- a/ui/src/types/model.ts
+++ b/ui/src/types/model.ts
@@ -5,6 +5,14 @@ export const ModelGenerationSchema = z.object({
inputMask: z.array(z.number()),
logits: z.array(z.array(z.number())),
logitsTokens: z.array(z.array(z.instanceof(Uint8Array))),
+ saeInfo: z.array(
+ z.object({
+ name: z.string(),
+ featureActsIndices: z.array(z.array(z.number())),
+ featureActs: z.array(z.array(z.number())),
+ maxFeatureActs: z.array(z.array(z.number())),
+ })
+ ),
});
export type ModelGeneration = z.infer;