diff --git a/ui/bun.lockb b/ui/bun.lockb index 2194dfdd..3bd98f13 100755 Binary files a/ui/bun.lockb and b/ui/bun.lockb differ diff --git a/ui/package.json b/ui/package.json index 0c8a2111..717bf84e 100644 --- a/ui/package.json +++ b/ui/package.json @@ -12,18 +12,23 @@ "dependencies": { "@msgpack/msgpack": "^3.0.0-beta2", "@radix-ui/react-accordion": "^1.1.2", + "@radix-ui/react-dialog": "^1.1.1", "@radix-ui/react-dropdown-menu": "^2.0.6", "@radix-ui/react-hover-card": "^1.0.7", "@radix-ui/react-label": "^2.0.2", + "@radix-ui/react-popover": "^1.1.1", "@radix-ui/react-select": "^2.0.0", "@radix-ui/react-separator": "^1.0.3", "@radix-ui/react-slot": "^1.0.2", + "@radix-ui/react-switch": "^1.1.0", "@radix-ui/react-tabs": "^1.0.4", "@radix-ui/react-toggle": "^1.0.3", + "@radix-ui/react-tooltip": "^1.1.2", "@tanstack/react-table": "^8.15.3", "camelcase-keys": "^9.1.3", "class-variance-authority": "^0.7.0", "clsx": "^2.1.0", + "cmdk": "1.0.0", "lucide-react": "^0.358.0", "plotly.js": "^2.30.1", "react": "^18.2.0", @@ -32,6 +37,7 @@ "react-router-dom": "^6.22.3", "react-use": "^17.5.0", "recharts": "^2.12.7", + "snakecase-keys": "^8.0.1", "tailwind-merge": "^2.2.1", "tailwindcss-animate": "^1.0.7", "zod": "^3.22.4", diff --git a/ui/src/components/model/model-card.tsx b/ui/src/components/model/model-card.tsx index ce14e7b6..d89ab70f 100644 --- a/ui/src/components/model/model-card.tsx +++ b/ui/src/components/model/model-card.tsx @@ -1,9 +1,10 @@ -import { Fragment, useEffect, useState } from "react"; +import { Fragment, useEffect, useMemo, useState } from "react"; import { Button } from "../ui/button"; import { Card, CardContent, CardHeader, CardTitle } from "../ui/card"; import { Textarea } from "../ui/textarea"; -import { useAsyncFn } from "react-use"; +import { useAsyncFn, useMount } from "react-use"; import camelcaseKeys from "camelcase-keys"; +import snakecaseKeys from "snakecase-keys"; import { decode } from "@msgpack/msgpack"; import { ModelGeneration, ModelGenerationSchema } from "@/types/model"; import { Sample } from "../app/sample"; @@ -14,53 +15,330 @@ import { Input } from "../ui/input"; import { countTokenGroupPositions, groupToken, hex } from "@/utils/token"; import { getAccentClassname } from "@/utils/style"; import { Separator } from "../ui/separator"; +import MultipleSelector from "../ui/multiple-selector"; +import { z } from "zod"; +import { Combobox } from "../ui/combobox"; +import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "../ui/select"; +import { Ban, MoreHorizontal, Plus, Trash2, Wrench, X } from "lucide-react"; +import { ColumnDef } from "@tanstack/react-table"; +import { FeatureLinkWithPreview } from "../app/feature-preview"; +import { DataTable } from "../ui/data-table"; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuLabel, + DropdownMenuSeparator, + DropdownMenuTrigger, +} from "../ui/dropdown-menu"; +import { useNavigate } from "react-router-dom"; +import { Switch } from "../ui/switch"; +import { Label as SLabel } from "../ui/label"; +import { Toggle } from "../ui/toggle"; -const ModelSample = ({ sample }: { sample: ModelGeneration }) => { +const SAEInfo = ({ + saeInfo, + saeSettings, + onSteerFeature, + setSAESettings, +}: { + saeInfo: { + name: string; + featureActs: { + featureActIndex: number; + featureAct: number; + maxFeatureAct: number; + }[]; + }; + saeSettings: { sortedBySum: boolean }; + onSteerFeature?: (name: string, featureIndex: number) => void; + setSAESettings: (settings: { sortedBySum: boolean }) => void; +}) => { + const navigate = useNavigate(); + + const columns: ColumnDef<{ + featureActIndex: number; + featureAct: number; + maxFeatureAct: number; + }>[] = [ + { + accessorKey: "featureActIndex", + header: () => ( +
+ Feature +
+ ), + cell: ({ row }) => ( +
+ +
+ ), + }, + { + accessorKey: "featureAct", + header: () => ( +
+ Activation +
+ ), + cell: ({ row }) => ( +
+ {row.original.featureAct.toFixed(3)} +
+ ), + }, + { + id: "actions", + enableHiding: false, + cell: ({ row }) => ( + + + + + + Actions + onSteerFeature?.(saeInfo.name, row.original.featureActIndex)}> + Steer feature + + + { + navigate( + `/features?dictionary=${encodeURIComponent(saeInfo.name)}&featureIndex=${ + row.original.featureActIndex + }` + ); + }} + > + View feature + + + + ), + meta: { + headerClassName: "w-16", + cellClassName: "py-0", + }, + }, + ]; + + return ( +
+
+
Features from {saeInfo.name}:
+ setSAESettings({ ...saeSettings, sortedBySum: pressed })} + > + Sort by sum + +
+ +
+ ); +}; + +const LogitsInfo = ({ + logits, +}: { + logits: { + logits: number; + token: Uint8Array; + }[]; +}) => { + const maxLogits = Math.max(...logits.map((logit) => logit.logits)); + const columns: ColumnDef<{ + logits: number; + token: Uint8Array; + }>[] = [ + { + accessorKey: "token", + header: () => ( +
+ Token +
+ ), + cell: ({ row }) => ( +
+ {hex(row.original)} +
+ ), + }, + { + accessorKey: "logits", + header: () => ( +
+ Logits +
+ ), + cell: ({ row }) => ( +
+ {row.original.logits.toFixed(3)} +
+ ), + }, + ]; + + return ( +
+
Logits:
+ +
+ ); +}; + +const LogitsBarChart = ({ + tokens, +}: { + tokens: { + logits: { + logits: number; + token: Uint8Array; + }[]; + token: Uint8Array; + }[]; +}) => { + const data = tokens.map((token) => + Object.assign( + {}, + ...token.logits.map((logit, j) => ({ + [`logits-${j}`]: logit.logits, + [`logits-token-${j}`]: hex(logit), + })), + { + name: hex(token), + } + ) + ); + + const colors = ["#8884d8", "#82ca9d", "#ffc658", "#ff7300", "#d6d6d6"]; + + return ( + + + + + + + {tokens[0].logits.slice(0, 5).map((_, i) => ( + + + + ))} + + + ); +}; + +const ModelSample = ({ + sample, + onSteerFeature, +}: { + sample: ModelGeneration; + onSteerFeature?: (name: string, featureIndex: number) => void; +}) => { const [selectedTokenGroupIndices, setSelectedTokenGroupIndices] = useState([]); const toggleSelectedTokenGroupIndex = (tokenGroupIndex: number) => { setSelectedTokenGroupIndices((prev) => - prev.includes(tokenGroupIndex) ? prev.filter((t) => t !== tokenGroupIndex) : [...prev, tokenGroupIndex] + prev.includes(tokenGroupIndex) + ? prev.filter((t) => t !== tokenGroupIndex) + : [...prev, tokenGroupIndex].sort((a, b) => a - b) ); }; + const [saeSettings, setSAESettings] = useState<{ [name: string]: { sortedBySum: boolean } }>({}); + const getSAESettings = (name: string) => saeSettings[name] || { sortedBySum: false }; + useEffect(() => { setSelectedTokenGroupIndices([]); }, [sample]); - const tokens = zip(sample.context, sample.inputMask, sample.logits, sample.logitsTokens).map( - ([token, inputMask, logits, logitsTokens]) => ({ - token, - inputMask, - logits: zip(logits, logitsTokens).map(([logits, token]) => ({ - logits, + const tokens = useMemo(() => { + const saeInfo = + sample.saeInfo.length > 0 + ? zip( + ...sample.saeInfo.map((sae) => + zip(sae.featureActsIndices, sae.featureActs, sae.maxFeatureActs).map( + ([featureActIndex, featureAct, maxFeatureAct]) => ({ + name: sae.name, + featureActs: zip(featureActIndex, featureAct, maxFeatureAct).map( + ([featureActIndex, featureAct, maxFeatureAct]) => ({ + featureActIndex, + featureAct, + maxFeatureAct, + }) + ), + }) + ) + ) + ) + : sample.context.map(() => []); + + return zip(sample.context, sample.inputMask, sample.logits, sample.logitsTokens, saeInfo).map( + ([token, inputMask, logits, logitsTokens, saeInfo]) => ({ token, + inputMask, + logits: zip(logits, logitsTokens).map(([logits, token]) => ({ + logits, + token, + })), + saeInfo, + }) + ); + }, [sample]); + + type Token = (typeof tokens)[0]; + + const sortTokenInfo = (tokens: Token[]) => { + const featureActSum = tokens.reduce((acc, token) => { + token.saeInfo.forEach((saeInfo) => { + saeInfo.featureActs.forEach((featureAct) => { + acc[saeInfo.name] = acc[saeInfo.name] || {}; + acc[saeInfo.name][featureAct.featureActIndex.toString()] = + acc[saeInfo.name][featureAct.featureActIndex.toString()] || 0; + acc[saeInfo.name][featureAct.featureActIndex.toString()] += featureAct.featureAct; + }); + }); + return acc; + }, {} as { [name: string]: { [featureIndex: string]: number } }); + + return tokens.map((token) => ({ + ...token, + logits: token.logits.sort((a, b) => b.logits - a.logits), + saeInfo: token.saeInfo.map((saeInfo) => ({ + ...saeInfo, + featureActs: getSAESettings(saeInfo.name).sortedBySum + ? saeInfo.featureActs.sort( + (a, b) => + featureActSum[saeInfo.name][b.featureActIndex.toString()] - + featureActSum[saeInfo.name][a.featureActIndex.toString()] + ) + : saeInfo.featureActs.sort((a, b) => b.featureAct - a.featureAct), })), - }) - ); + })); + }; + const tokenGroups = groupToken(tokens); const tokenGroupPositions = countTokenGroupPositions(tokenGroups); const selectedTokenGroups = selectedTokenGroupIndices.map((i) => tokenGroups[i]); - const selectedTokens = selectedTokenGroups.flatMap((t) => t); + const selectedTokens = sortTokenInfo(selectedTokenGroups.flatMap((t) => t)); const selectedTokenGroupPositions = selectedTokenGroupIndices.map((i) => tokenGroupPositions[i]); const selectedTokenPositions = selectedTokenGroups.flatMap((t, i) => t.map((_, j) => selectedTokenGroupPositions[i] + j) ); - const data = selectedTokens.map((token) => - Object.assign( - {}, - ...token.logits.map((logits, j) => ({ - [`logits-${j}`]: logits.logits, - [`logits-token-${j}`]: hex(logits), - })), - { - name: hex(token), - } - ) - ); - - const colors = ["#8884d8", "#82ca9d", "#ffc658", "#ff7300", "#d6d6d6"]; - return (
{ })} /> - {selectedTokens.length > 0 &&

Detail of Selected Tokens:

} + {selectedTokens.length > 0 && ( +

+ Detail of {selectedTokens.length} Selected Token{selectedTokens.length > 1 ? "s" : ""}: +

+ )} {selectedTokens.map((token, i) => ( -
-
-
-
Token:
-
{hex(token)}
-
Position:
-
{selectedTokenPositions[i]}
-
-
-
-

Top Logits:

-
- {token.logits.map((logit, j) => ( - -
{hex(logit)}
-
- {logit.logits.toFixed(3)} -
-
- ))} -
+
+
+
Token:
+
{hex(token)}
+
Position:
+
{selectedTokenPositions[i]}
+
+ + {token.saeInfo.map((saeInfo, j) => ( + setSAESettings({ ...saeSettings, [saeInfo.name]: settings })} + /> + ))} +
{i < selectedTokens.length - 1 && } ))} - {selectedTokens.length > 0 && ( - - - - - - - {selectedTokens[0].logits.map((_, i) => ( - - - - ))} - - - )} + {selectedTokens.length > 0 && }
); }; const ModelCustomInputArea = () => { const [customInput, setCustomInput] = useState(""); + const [doGenerate, setDoGenerate] = useState(true); const [maxNewTokens, setMaxNewTokens] = useState(128); const [topK, setTopK] = useState(50); const [topP, setTopP] = useState(0.95); - const [logitTopK, setLogitTopK] = useState(5); + const [selectedDictionaries, setSelectedDictionaries] = useState([]); + const [steerings, setSteerings] = useState< + { + sae: string | null; + featureIndex: number; + steeringType: "times" | "ablate" | "add" | "set"; + steeringValue: number | null; + }[] + >([{ sae: null, featureIndex: 0, steeringType: "times", steeringValue: 1 }]); + const [sample, setSample] = useState(null); + + const [dictionariesState, fetchDictionaries] = useAsyncFn(async () => { + return await fetch(`${import.meta.env.VITE_BACKEND_URL}/dictionaries`) + .then(async (res) => await res.json()) + .then((res) => z.array(z.string()).parse(res)); + }); + + useMount(async () => { + await fetchDictionaries(); + }); + const [state, submit] = useAsyncFn(async () => { if (!customInput) { alert("Please enter your input."); return; } - const sample = await fetch( - `${import.meta.env.VITE_BACKEND_URL}/model/generate?input_text=${encodeURIComponent( - customInput - )}&max_new_tokens=${encodeURIComponent(maxNewTokens.toString())}&top_k=${encodeURIComponent( - topK.toString() - )}&top_p=${encodeURIComponent(topP.toString())}&return_logits_top_k=${encodeURIComponent(logitTopK.toString())}`, - { - method: "POST", - headers: { - Accept: "application/x-msgpack", - }, - } - ) + const sample = await fetch(`${import.meta.env.VITE_BACKEND_URL}/model/generate`, { + method: "POST", + body: JSON.stringify( + snakecaseKeys({ + inputText: customInput, + maxNewTokens: doGenerate ? maxNewTokens : 0, + topK, + topP, + saes: selectedDictionaries, + steerings: steerings.filter((s) => s.sae !== null), + }) + ), + headers: { + Accept: "application/x-msgpack", + "Content-Type": "application/json", + }, + }) .then(async (res) => { if (!res.ok) { throw new Error(await res.text()); @@ -170,15 +459,24 @@ const ModelCustomInputArea = () => { ) .then((res) => ModelGenerationSchema.parse(res)); setSample(sample); - }, [customInput]); + }, [doGenerate, customInput, maxNewTokens, topK, topP, selectedDictionaries]); return (

Generation

+ Do generate: +
+ + + {doGenerate + ? "Model will generate new contents based on the following configuration." + : "Model will only perceive the given input."} + +
Max new tokens: { /> Top K: { /> Top P: setTopP(parseFloat(e.target.value))} /> - Logit Top K: - + SAEs: + setLogitTopK(parseInt(e.target.value))} + disabled={state.loading} + options={dictionariesState.value?.map((name) => ({ value: name, label: name })) || []} + commandProps={{ + className: "col-span-3", + }} + hidePlaceholderWhenSelected + placeholder="Bind SAEs to the language model to see features activated in the generation." + value={selectedDictionaries.map((name) => ({ value: name, label: name }))} + onChange={(value) => setSelectedDictionaries(value.map((v) => v.value))} + emptyIndicator={ +

No dictionaries found.

+ } /> + {selectedDictionaries.length > 0 && + steerings.map((steering, i) => ( + + Steering {i + 1}: +
+
+ { + setSteerings((prev) => prev.map((s, j) => (i === j ? { ...s, sae: value } : s))); + }} + options={selectedDictionaries.map((name) => ({ value: name, label: name }))} + disabled={state.loading} + placeholder="Select a dictionary to steer the generation." + commandPlaceholder="Search for a selected dictionary..." + emptyIndicator="No options found" + /> + # + { + setSteerings((prev) => + prev.map((s, j) => (i === j ? { ...s, featureIndex: parseInt(e.target.value) } : s)) + ); + }} + /> + + { + setSteerings((prev) => + prev.map((s, j) => (i === j ? { ...s, steeringValue: parseFloat(e.target.value) } : s)) + ); + }} + /> +
+
+ steerings.length > 1 + ? setSteerings((prev) => prev.filter((_, j) => i !== j)) + : setSteerings([{ sae: null, featureIndex: 0, steeringType: "times", steeringValue: 1 }]) + } + > + +
+
+ setSteerings((prev) => [ + ...prev.slice(0, i + 1), + { sae: null, featureIndex: 0, steeringType: "times", steeringValue: 1 }, + ...prev.slice(i + 1), + ]) + } + > + +
+
+
+ ))}