diff --git a/server/app.py b/server/app.py index c9b9ad2..049edf3 100644 --- a/server/app.py +++ b/server/app.py @@ -298,16 +298,16 @@ def dictionary_custom_input(dictionary_name: str, input_text: str): return Response(content=msgpack.packb(sample), media_type="application/x-msgpack") @app.post("/model/generate") -def model_generate(input_text: str, max_length: int = 128, top_k: int = 50, top_p: float = 0.95): +def model_generate(input_text: str, max_new_tokens: int = 128, top_k: int = 50, top_p: float = 0.95, return_logits_top_k: int = 5): dictionaries = client.list_dictionaries(dictionary_series=dictionary_series) assert len(dictionaries) > 0, "No dictionaries found. Model name cannot be inferred." model = get_model(dictionaries[0]) with torch.no_grad(): input = model.to_tokens(input_text, prepend_bos=False) - output = model.generate(input, max_length=max_length, top_k=top_k, top_p=top_p) + output = model.generate(input, max_new_tokens=max_new_tokens, top_k=top_k, top_p=top_p) output = output.clone() logits = model.forward(output) - logits_topk = [torch.topk(l, top_k) for l in logits[0]] + logits_topk = [torch.topk(l, return_logits_top_k) for l in logits[0]] result = { "context": [ bytearray([byte_decoder[c] for c in t]) @@ -327,10 +327,10 @@ def model_generate(input_text: str, max_length: int = 128, top_k: int = 50, top_ @app.post("/dictionaries/{dictionary_name}/features/{feature_index}/interpret") def feature_interpretation( - dictionary_name: str, - feature_index: int, - type: str, - custom_interpretation: str | None = None, + dictionary_name: str, + feature_index: int, + type: str, + custom_interpretation: str | None = None, ): model = get_model(dictionary_name) if type == "custom": diff --git a/ui/bun.lockb b/ui/bun.lockb index 6ee4bfe..2194dfd 100755 Binary files a/ui/bun.lockb and b/ui/bun.lockb differ diff --git a/ui/package.json b/ui/package.json index 93486e7..0c8a211 100644 --- a/ui/package.json +++ b/ui/package.json @@ -31,6 +31,7 @@ "react-plotly.js": "^2.6.0", "react-router-dom": "^6.22.3", "react-use": "^17.5.0", + "recharts": "^2.12.7", "tailwind-merge": "^2.2.1", "tailwindcss-animate": "^1.0.7", "zod": "^3.22.4", diff --git a/ui/src/components/model/model-card.tsx b/ui/src/components/model/model-card.tsx index 6774e4c..81534d3 100644 --- a/ui/src/components/model/model-card.tsx +++ b/ui/src/components/model/model-card.tsx @@ -2,22 +2,198 @@ import { useState } from "react"; import { Button } from "../ui/button"; import { Card, CardContent, CardHeader, CardTitle } from "../ui/card"; import { Textarea } from "../ui/textarea"; +import { useAsyncFn } from "react-use"; +import camelcaseKeys from "camelcase-keys"; +import { decode } from "@msgpack/msgpack"; +import { ModelGeneration, ModelGenerationSchema } from "@/types/model"; +import { SimpleSampleArea } from "../app/sample"; +import { cn } from "@/lib/utils"; +import { BarChart, Bar, XAxis, YAxis, CartesianGrid, Label, LabelList, ResponsiveContainer } from "recharts"; +import { mergeUint8Arrays, zip } from "@/utils/array"; +import { Input } from "../ui/input"; + +const ModelSampleArea = ({ sample }: { sample: ModelGeneration }) => { + const [selectedTokenGroupIndices, setSelectedTokenGroupIndices] = useState([]); + const toggleSelectedTokenGroupIndex = (tokenGroupIndex: number) => { + setSelectedTokenGroupIndices((prev) => + prev.includes(tokenGroupIndex) ? prev.filter((t) => t !== tokenGroupIndex) : [...prev, tokenGroupIndex] + ); + }; + + const decoder = new TextDecoder("utf-8", { fatal: true }); + const tokens = sample.context.map((token, i) => ({ + token, + inputMask: sample.inputMask[i], + logits: zip(sample.logits[i], sample.logitsTokens[i]).map(([logits, logitsTokens]) => ({ + logits, + logitsTokens, + })), + })); + const tokenGroups = tokens.reduce<[(typeof tokens)[], typeof tokens]>( + ([groups, currentGroup], token) => { + const newGroup = [...currentGroup, token]; + try { + decoder.decode(mergeUint8Arrays(newGroup.map((t) => t.token))); + return [[...groups, newGroup], []]; + } catch { + return [groups, newGroup]; + } + }, + [[], []] + )[0]; + const selectedTokenGroups = selectedTokenGroupIndices.map((i) => tokenGroups[i]); + const selectedTokens = selectedTokenGroups.flatMap((t) => t); + + const data = selectedTokens.map((token) => + Object.assign( + {}, + ...token.logits.map((logits, j) => ({ + [`logits-${j}`]: logits.logits, + [`logits-token-${j}`]: logits.logitsTokens.reduce( + (acc, b) => + b < 32 || b > 126 ? `${acc}\\x${b.toString(16).padStart(2, "0")}` : `${acc}${String.fromCharCode(b)}`, + "" + ), + })), + { + name: token.token.reduce( + (acc, b) => + b < 32 || b > 126 ? `${acc}\\x${b.toString(16).padStart(2, "0")}` : `${acc}${String.fromCharCode(b)}`, + "" + ), + } + ) + ); + + const colors = ["#8884d8", "#82ca9d", "#ffc658", "#ff7300", "#d6d6d6"]; + + return ( +
+ + cn( + "hover:shadow-lg hover:text-gray-600 cursor-pointer", + selectedTokenGroupIndices.some((t) => t === tokenIndex) && "bg-orange-500" + ) + } + tokenGroupProps={(_, i) => ({ + onClick: () => toggleSelectedTokenGroupIndex(i), + })} + /> + + {selectedTokens.length > 0 && ( + + + + + + + {selectedTokens[0].logits.map((_, i) => ( + + + + ))} + + + )} +
+ ); +}; const ModelCustomInputArea = () => { const [customInput, setCustomInput] = useState(""); - const submit = async () => {}; - const disabled = false; + const [maxNewTokens, setMaxNewTokens] = useState(128); + const [topK, setTopK] = useState(50); + const [topP, setTopP] = useState(0.95); + const [logitTopK, setLogitTopK] = useState(5); + const [sample, setSample] = useState(null); + const [state, submit] = useAsyncFn(async () => { + if (!customInput) { + alert("Please enter your input."); + return; + } + const sample = await fetch( + `${import.meta.env.VITE_BACKEND_URL}/model/generate?input_text=${encodeURIComponent( + customInput + )}&max_new_tokens=${encodeURIComponent(maxNewTokens.toString())}&top_k=${encodeURIComponent( + topK.toString() + )}&top_p=${encodeURIComponent(topP.toString())}&return_logits_top_k=${encodeURIComponent(logitTopK.toString())}`, + { + method: "POST", + headers: { + Accept: "application/x-msgpack", + }, + } + ) + .then(async (res) => { + if (!res.ok) { + throw new Error(await res.text()); + } + return res; + }) + .then(async (res) => await res.arrayBuffer()) + // eslint-disable-next-line @typescript-eslint/no-explicit-any + .then((res) => decode(new Uint8Array(res)) as any) + .then((res) => + camelcaseKeys(res, { + deep: true, + stopPaths: ["context", "logits_tokens"], + }) + ) + .then((res) => ModelGenerationSchema.parse(res)); + setSample(sample); + }, [customInput]); + return (
-

Custom Input

+

Generation

+
+ Max new tokens: + setMaxNewTokens(parseInt(e.target.value))} + /> + Top K: + setTopK(parseInt(e.target.value))} + /> + Top P: + setTopP(parseFloat(e.target.value))} + /> + Logit Top K: + setLogitTopK(parseInt(e.target.value))} + /> +