Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Model Page with Generation Section #45

Merged
merged 4 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,13 +297,40 @@ def dictionary_custom_input(dictionary_name: str, input_text: str):

return Response(content=msgpack.packb(sample), media_type="application/x-msgpack")

@app.post("/model/generate")
def model_generate(input_text: str, max_new_tokens: int = 128, top_k: int = 50, top_p: float = 0.95, return_logits_top_k: int = 5):
    """Generate a continuation of ``input_text`` and report the top logits at every position.

    The model is looked up from the first dictionary returned by
    ``client.list_dictionaries`` for the configured ``dictionary_series``.

    Args:
        input_text: Prompt text; tokenized without a prepended BOS token.
        max_new_tokens: Forwarded to ``model.generate``.
        top_k: Forwarded to ``model.generate``.
        top_p: Forwarded to ``model.generate``.
        return_logits_top_k: How many of the highest logits (and their tokens)
            to return for each position of the full sequence.

    Returns:
        A msgpack-encoded ``Response`` with the keys ``context`` (token bytes of
        prompt + generation), ``logits`` (top-k values per position),
        ``logits_tokens`` (token bytes matching those values) and ``input_mask``
        (1 for prompt positions, 0 for generated ones). If no dictionary is
        available, a 404 ``Response`` carrying a plain-text explanation.
    """
    dictionaries = client.list_dictionaries(dictionary_series=dictionary_series)
    if len(dictionaries) == 0:
        # Reply explicitly instead of using `assert`: assertions vanish under
        # `python -O`, and an AssertionError reaches the client as an opaque 500
        # whose body does not contain this message.
        return Response(
            content="No dictionaries found. Model name cannot be inferred.",
            status_code=404,
        )
    model = get_model(dictionaries[0])
    with torch.no_grad():
        # `input_tokens` rather than `input`, so the builtin is not shadowed.
        input_tokens = model.to_tokens(input_text, prepend_bos=False)
        output = model.generate(input_tokens, max_new_tokens=max_new_tokens, top_k=top_k, top_p=top_p)
        # NOTE(review): presumably the clone detaches the result from whatever
        # mode `generate` ran in before it is fed back to `forward` - confirm.
        output = output.clone()
        # Re-run the whole sequence to obtain logits for every position.
        logits = model.forward(output)
        logits_topk = [torch.topk(l, return_logits_top_k) for l in logits[0]]
        n_input = len(input_tokens[0])
        n_generated = len(output[0]) - n_input
        result = {
            # Tokens are shipped as raw bytes; `byte_decoder` maps each token
            # character to its byte value.
            "context": [
                bytearray([byte_decoder[c] for c in t])
                for t in model.tokenizer.convert_ids_to_tokens(output[0])
            ],
            "logits": [l.values.cpu().numpy().tolist() for l in logits_topk],
            "logits_tokens": [
                [
                    bytearray([byte_decoder[c] for c in t])
                    for t in model.tokenizer.convert_ids_to_tokens(l.indices)
                ] for l in logits_topk
            ],
            "input_mask": [1] * n_input + [0] * n_generated,
        }
    return Response(content=msgpack.packb(result), media_type="application/x-msgpack")


@app.post("/dictionaries/{dictionary_name}/features/{feature_index}/interpret")
def feature_interpretation(
dictionary_name: str,
feature_index: int,
type: str,
custom_interpretation: str | None = None,
dictionary_name: str,
feature_index: int,
type: str,
custom_interpretation: str | None = None,
):
model = get_model(dictionary_name)
if type == "custom":
Expand Down
Binary file modified ui/bun.lockb
Binary file not shown.
1 change: 1 addition & 0 deletions ui/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"react-plotly.js": "^2.6.0",
"react-router-dom": "^6.22.3",
"react-use": "^17.5.0",
"recharts": "^2.12.7",
"tailwind-merge": "^2.2.1",
"tailwindcss-animate": "^1.0.7",
"zod": "^3.22.4",
Expand Down
9 changes: 9 additions & 0 deletions ui/src/components/app/navbar.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ export const AppNavbar = () => {
>
Dictionaries
</Link>
<Link
className={cn(
"transition-colors hover:text-foreground/80 text-foreground/60",
location.pathname === "/models" && "text-foreground"
)}
to="/models"
>
Models
</Link>
</div>
</div>
</nav>
Expand Down
216 changes: 216 additions & 0 deletions ui/src/components/model/model-card.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
import { useState } from "react";
import { Button } from "../ui/button";
import { Card, CardContent, CardHeader, CardTitle } from "../ui/card";
import { Textarea } from "../ui/textarea";
import { useAsyncFn } from "react-use";
import camelcaseKeys from "camelcase-keys";
import { decode } from "@msgpack/msgpack";
import { ModelGeneration, ModelGenerationSchema } from "@/types/model";
import { SimpleSampleArea } from "../app/sample";
import { cn } from "@/lib/utils";
import { BarChart, Bar, XAxis, YAxis, CartesianGrid, Label, LabelList, ResponsiveContainer } from "recharts";
import { mergeUint8Arrays, zip } from "@/utils/array";
import { Input } from "../ui/input";

/** Fill colors for the per-rank logit bars; reused cyclically when more ranks are shown than colors exist. */
const LOGIT_BAR_COLORS = ["#8884d8", "#82ca9d", "#ffc658", "#ff7300", "#d6d6d6"];

/**
 * Renders raw token bytes as a chart label: printable ASCII (32-126) is kept
 * as-is, every other byte is written as a `\xNN` hex escape.
 */
const formatTokenBytes = (bytes: Uint8Array): string =>
  Array.from(bytes, (b) =>
    b < 32 || b > 126 ? `\\x${b.toString(16).padStart(2, "0")}` : String.fromCharCode(b)
  ).join("");

/**
 * Shows a generated sample and, for the token groups the user clicks, a bar
 * chart of the top logits returned for each selected token.
 *
 * @param sample - Parsed generation result (token bytes, input mask, top-k logits and their tokens).
 */
const ModelSampleArea = ({ sample }: { sample: ModelGeneration }) => {
  const [selectedTokenGroupIndices, setSelectedTokenGroupIndices] = useState<number[]>([]);
  const toggleSelectedTokenGroupIndex = (tokenGroupIndex: number) => {
    setSelectedTokenGroupIndices((prev) =>
      prev.includes(tokenGroupIndex) ? prev.filter((t) => t !== tokenGroupIndex) : [...prev, tokenGroupIndex]
    );
  };

  // `fatal: true` makes decode() throw on incomplete UTF-8, which is what the grouping below relies on.
  const decoder = new TextDecoder("utf-8", { fatal: true });
  const tokens = sample.context.map((token, i) => ({
    token,
    inputMask: sample.inputMask[i],
    logits: zip(sample.logits[i], sample.logitsTokens[i]).map(([logits, logitsTokens]) => ({
      logits,
      logitsTokens,
    })),
  }));
  // Merge consecutive tokens until their concatenated bytes decode as valid UTF-8, so a
  // multi-byte character split across tokens ends up in one group. A trailing run that
  // never becomes decodable is not emitted as a group.
  const tokenGroups = tokens.reduce<[(typeof tokens)[], typeof tokens]>(
    ([groups, currentGroup], token) => {
      const newGroup = [...currentGroup, token];
      try {
        decoder.decode(mergeUint8Arrays(newGroup.map((t) => t.token)));
        return [[...groups, newGroup], []];
      } catch {
        return [groups, newGroup];
      }
    },
    [[], []]
  )[0];
  // The selection state outlives the `sample` prop, so after a new generation it may hold
  // indices that no longer exist; skip those instead of crashing on `undefined`.
  const selectedTokens = selectedTokenGroupIndices.flatMap((i) => tokenGroups[i] ?? []);

  // One chart row per selected token: `name` is the token itself, `logits-j` / `logits-token-j`
  // carry the value and label of the j-th ranked candidate.
  const data = selectedTokens.map((token) => {
    const row: Record<string, string | number> = { name: formatTokenBytes(token.token) };
    token.logits.forEach((entry, j) => {
      row[`logits-${j}`] = entry.logits;
      row[`logits-token-${j}`] = formatTokenBytes(entry.logitsTokens);
    });
    return row;
  });

  return (
    <div className="flex flex-col gap-4">
      <SimpleSampleArea
        sample={sample}
        sampleName={`Generation`}
        tokenGroupClassName={(_, tokenIndex) =>
          cn(
            "hover:shadow-lg hover:text-gray-600 cursor-pointer",
            selectedTokenGroupIndices.some((t) => t === tokenIndex) && "bg-orange-500"
          )
        }
        tokenGroupProps={(_, i) => ({
          onClick: () => toggleSelectedTokenGroupIndex(i),
        })}
      />

      {selectedTokens.length > 0 && (
        <ResponsiveContainer height={300}>
          <BarChart data={data} margin={{ top: 50, right: 50, left: 50, bottom: 15 }}>
            <CartesianGrid strokeDasharray="3 3" />
            <XAxis dataKey="name">
              <Label value="Tokens" offset={0} position="bottom" />
            </XAxis>
            <YAxis label={{ value: "Logits", angle: -90, position: "left", textAnchor: "middle" }} />
            {selectedTokens[0].logits.map((_, i) => (
              // Wrap around the palette so ranks beyond its length still get a fill color.
              <Bar key={i} dataKey={`logits-${i}`} fill={LOGIT_BAR_COLORS[i % LOGIT_BAR_COLORS.length]}>
                <LabelList dataKey={`logits-token-${i}`} position="top" />
              </Bar>
            ))}
          </BarChart>
        </ResponsiveContainer>
      )}
    </div>
  );
};

/**
 * Form for running a generation against the backend: collects the prompt and
 * sampling parameters, POSTs them to `/model/generate`, and renders the
 * decoded result (or the error text) below the form.
 */
const ModelCustomInputArea = () => {
  const [customInput, setCustomInput] = useState<string>("");
  const [maxNewTokens, setMaxNewTokens] = useState<number>(128);
  const [topK, setTopK] = useState<number>(50);
  const [topP, setTopP] = useState<number>(0.95);
  const [logitTopK, setLogitTopK] = useState<number>(5);
  const [sample, setSample] = useState<ModelGeneration | null>(null);
  const [state, submit] = useAsyncFn(async () => {
    if (!customInput) {
      alert("Please enter your input.");
      return;
    }

    // The backend reads every argument from the query string.
    const query = [
      ["input_text", customInput],
      ["max_new_tokens", maxNewTokens.toString()],
      ["top_k", topK.toString()],
      ["top_p", topP.toString()],
      ["return_logits_top_k", logitTopK.toString()],
    ]
      .map(([key, value]) => `${key}=${encodeURIComponent(value)}`)
      .join("&");

    const res = await fetch(`${import.meta.env.VITE_BACKEND_URL}/model/generate?${query}`, {
      method: "POST",
      headers: {
        Accept: "application/x-msgpack",
      },
    });
    if (!res.ok) {
      // Surface the server's message; useAsyncFn exposes it through `state.error`.
      throw new Error(await res.text());
    }

    const decoded = decode(new Uint8Array(await res.arrayBuffer())) as Record<string, unknown>;
    // Convert snake_case keys to camelCase, but do not descend into the byte payloads.
    const camelCased = camelcaseKeys(decoded, {
      deep: true,
      stopPaths: ["context", "logits_tokens"],
    });
    setSample(ModelGenerationSchema.parse(camelCased));
    // Every value read above must be listed here; otherwise the memoized callback keeps
    // sending the sampling parameters from the render in which the text last changed.
  }, [customInput, maxNewTokens, topK, topP, logitTopK]);

  return (
    <div className="flex flex-col gap-4">
      <p className="font-bold">Generation</p>
      <div className="container grid grid-cols-4 justify-center items-center gap-4 px-20">
        <span className="font-bold justify-self-end">Max new tokens:</span>
        <Input
          disabled={state.loading}
          className="bg-white"
          type="number"
          value={maxNewTokens.toString()}
          onChange={(e) => setMaxNewTokens(parseInt(e.target.value))}
        />
        <span className="font-bold justify-self-end">Top K:</span>
        <Input
          disabled={state.loading}
          className="bg-white"
          type="number"
          value={topK.toString()}
          onChange={(e) => setTopK(parseInt(e.target.value))}
        />
        <span className="font-bold justify-self-end">Top P:</span>
        <Input
          disabled={state.loading}
          className="bg-white"
          type="number"
          value={topP.toString()}
          onChange={(e) => setTopP(parseFloat(e.target.value))}
        />
        <span className="font-bold justify-self-end">Logit Top K:</span>
        <Input
          disabled={state.loading}
          className="bg-white"
          type="number"
          value={logitTopK.toString()}
          onChange={(e) => setLogitTopK(parseInt(e.target.value))}
        />
      </div>
      <Textarea
        placeholder="Type your custom input here."
        value={customInput}
        onChange={(e) => setCustomInput(e.target.value)}
      />
      <Button onClick={submit} disabled={state.loading}>
        Submit
      </Button>
      {state.error && <p className="text-red-500">{state.error.message}</p>}
      {sample && <ModelSampleArea sample={sample} />}
    </div>
  );
};

/**
 * Card titled "Model" that hosts the custom-input generation form.
 */
export const ModelCard = () => (
  <Card className="container">
    <CardHeader>
      <CardTitle className="flex justify-between items-center text-xl">
        <span>Model</span>
      </CardTitle>
    </CardHeader>
    <CardContent>
      <div className="flex flex-col gap-4">
        <ModelCustomInputArea />
      </div>
    </CardContent>
  </Card>
);
5 changes: 5 additions & 0 deletions ui/src/main.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { FeaturesPage } from "@/routes/features/page";
import { RootPage } from "./routes/page";
import { AttentionHeadPage } from "./routes/attn-heads/page";
import { DictionaryPage } from "./routes/dictionaries/page";
import { ModelsPage } from "./routes/models/page";

const router = createBrowserRouter([
{
Expand All @@ -20,6 +21,10 @@ const router = createBrowserRouter([
path: "/attn-heads",
element: <AttentionHeadPage />,
},
{
path: "/models",
element: <ModelsPage />,
},
{
path: "/",
element: <RootPage />,
Expand Down
15 changes: 15 additions & 0 deletions ui/src/routes/models/page.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import { AppNavbar } from "@/components/app/navbar";
import { ModelCard } from "@/components/model/model-card";

/**
 * Route component for `/models`: the app navbar followed by the model card.
 */
export const ModelsPage = () => (
  <div id="Top">
    <AppNavbar />
    <div className="pt-4 pb-20 px-20 flex flex-col items-center gap-12">
      <div className="container flex gap-12">
        <ModelCard />
      </div>
    </div>
  </div>
);
10 changes: 10 additions & 0 deletions ui/src/types/model.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import { z } from "zod";

/** One token, carried as its raw bytes. */
const tokenBytesSchema = z.instanceof(Uint8Array);

/** A list of plain numbers (mask flags or logit values). */
const numberListSchema = z.array(z.number());

/**
 * Runtime shape of a decoded `/model/generate` response after its keys have
 * been converted to camelCase.
 *
 * - `context`: every token of the sequence, as bytes.
 * - `inputMask`: one number per token.
 * - `logits`: for each token, the returned logit values.
 * - `logitsTokens`: for each token, the tokens (as bytes) matching `logits`.
 */
export const ModelGenerationSchema = z.object({
  context: z.array(tokenBytesSchema),
  inputMask: numberListSchema,
  logits: z.array(numberListSchema),
  logitsTokens: z.array(z.array(tokenBytesSchema)),
});

/** Static type inferred from {@link ModelGenerationSchema}. */
export type ModelGeneration = z.infer<typeof ModelGenerationSchema>;