Skip to content

Commit

Permalink
Merge pull request #45 from OpenMOSS/ui/model-page
Browse files Browse the repository at this point in the history
Model Page with Generation Section
  • Loading branch information
Hzfinfdu authored Aug 5, 2024
2 parents 95f03de + fa98aea commit ad0ec78
Show file tree
Hide file tree
Showing 8 changed files with 287 additions and 4 deletions.
35 changes: 31 additions & 4 deletions server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,13 +297,40 @@ def dictionary_custom_input(dictionary_name: str, input_text: str):

return Response(content=msgpack.packb(sample), media_type="application/x-msgpack")

@app.post("/model/generate")
def model_generate(input_text: str, max_new_tokens: int = 128, top_k: int = 50, top_p: float = 0.95, return_logits_top_k: int = 5):
dictionaries = client.list_dictionaries(dictionary_series=dictionary_series)
assert len(dictionaries) > 0, "No dictionaries found. Model name cannot be inferred."
model = get_model(dictionaries[0])
with torch.no_grad():
input = model.to_tokens(input_text, prepend_bos=False)
output = model.generate(input, max_new_tokens=max_new_tokens, top_k=top_k, top_p=top_p)
output = output.clone()
logits = model.forward(output)
logits_topk = [torch.topk(l, return_logits_top_k) for l in logits[0]]
result = {
"context": [
bytearray([byte_decoder[c] for c in t])
for t in model.tokenizer.convert_ids_to_tokens(output[0])
],
"logits": [l.values.cpu().numpy().tolist() for l in logits_topk],
"logits_tokens": [
[
bytearray([byte_decoder[c] for c in t])
for t in model.tokenizer.convert_ids_to_tokens(l.indices)
] for l in logits_topk
],
"input_mask": [1 for _ in range(len(input[0]))] + [0 for _ in range(len(output[0]) - len(input[0]))],
}
return Response(content=msgpack.packb(result), media_type="application/x-msgpack")


@app.post("/dictionaries/{dictionary_name}/features/{feature_index}/interpret")
def feature_interpretation(
dictionary_name: str,
feature_index: int,
type: str,
custom_interpretation: str | None = None,
dictionary_name: str,
feature_index: int,
type: str,
custom_interpretation: str | None = None,
):
model = get_model(dictionary_name)
if type == "custom":
Expand Down
Binary file modified ui/bun.lockb
Binary file not shown.
1 change: 1 addition & 0 deletions ui/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"react-plotly.js": "^2.6.0",
"react-router-dom": "^6.22.3",
"react-use": "^17.5.0",
"recharts": "^2.12.7",
"tailwind-merge": "^2.2.1",
"tailwindcss-animate": "^1.0.7",
"zod": "^3.22.4",
Expand Down
9 changes: 9 additions & 0 deletions ui/src/components/app/navbar.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ export const AppNavbar = () => {
>
Dictionaries
</Link>
<Link
className={cn(
"transition-colors hover:text-foreground/80 text-foreground/60",
location.pathname === "/models" && "text-foreground"
)}
to="/models"
>
Models
</Link>
</div>
</div>
</nav>
Expand Down
216 changes: 216 additions & 0 deletions ui/src/components/model/model-card.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
import { useState } from "react";
import { Button } from "../ui/button";
import { Card, CardContent, CardHeader, CardTitle } from "../ui/card";
import { Textarea } from "../ui/textarea";
import { useAsyncFn } from "react-use";
import camelcaseKeys from "camelcase-keys";
import { decode } from "@msgpack/msgpack";
import { ModelGeneration, ModelGenerationSchema } from "@/types/model";
import { SimpleSampleArea } from "../app/sample";
import { cn } from "@/lib/utils";
import { BarChart, Bar, XAxis, YAxis, CartesianGrid, Label, LabelList, ResponsiveContainer } from "recharts";
import { mergeUint8Arrays, zip } from "@/utils/array";
import { Input } from "../ui/input";

const ModelSampleArea = ({ sample }: { sample: ModelGeneration }) => {
const [selectedTokenGroupIndices, setSelectedTokenGroupIndices] = useState<number[]>([]);
const toggleSelectedTokenGroupIndex = (tokenGroupIndex: number) => {
setSelectedTokenGroupIndices((prev) =>
prev.includes(tokenGroupIndex) ? prev.filter((t) => t !== tokenGroupIndex) : [...prev, tokenGroupIndex]
);
};

const decoder = new TextDecoder("utf-8", { fatal: true });
const tokens = sample.context.map((token, i) => ({
token,
inputMask: sample.inputMask[i],
logits: zip(sample.logits[i], sample.logitsTokens[i]).map(([logits, logitsTokens]) => ({
logits,
logitsTokens,
})),
}));
const tokenGroups = tokens.reduce<[(typeof tokens)[], typeof tokens]>(
([groups, currentGroup], token) => {
const newGroup = [...currentGroup, token];
try {
decoder.decode(mergeUint8Arrays(newGroup.map((t) => t.token)));
return [[...groups, newGroup], []];
} catch {
return [groups, newGroup];
}
},
[[], []]
)[0];
const selectedTokenGroups = selectedTokenGroupIndices.map((i) => tokenGroups[i]);
const selectedTokens = selectedTokenGroups.flatMap((t) => t);

const data = selectedTokens.map((token) =>
Object.assign(
{},
...token.logits.map((logits, j) => ({
[`logits-${j}`]: logits.logits,
[`logits-token-${j}`]: logits.logitsTokens.reduce(
(acc, b) =>
b < 32 || b > 126 ? `${acc}\\x${b.toString(16).padStart(2, "0")}` : `${acc}${String.fromCharCode(b)}`,
""
),
})),
{
name: token.token.reduce(
(acc, b) =>
b < 32 || b > 126 ? `${acc}\\x${b.toString(16).padStart(2, "0")}` : `${acc}${String.fromCharCode(b)}`,
""
),
}
)
);

const colors = ["#8884d8", "#82ca9d", "#ffc658", "#ff7300", "#d6d6d6"];

return (
<div className="flex flex-col gap-4">
<SimpleSampleArea
sample={sample}
sampleName={`Generation`}
tokenGroupClassName={(_, tokenIndex) =>
cn(
"hover:shadow-lg hover:text-gray-600 cursor-pointer",
selectedTokenGroupIndices.some((t) => t === tokenIndex) && "bg-orange-500"
)
}
tokenGroupProps={(_, i) => ({
onClick: () => toggleSelectedTokenGroupIndex(i),
})}
/>

{selectedTokens.length > 0 && (
<ResponsiveContainer height={300}>
<BarChart data={data} margin={{ top: 50, right: 50, left: 50, bottom: 15 }}>
<CartesianGrid strokeDasharray="3 3" />
<XAxis dataKey="name">
<Label value="Tokens" offset={0} position="bottom" />
</XAxis>
<YAxis label={{ value: "Logits", angle: -90, position: "left", textAnchor: "middle" }} />
{selectedTokens[0].logits.map((_, i) => (
<Bar key={i} dataKey={`logits-${i}`} fill={colors[i]}>
<LabelList dataKey={`logits-token-${i}`} position="top" />
</Bar>
))}
</BarChart>
</ResponsiveContainer>
)}
</div>
);
};

const ModelCustomInputArea = () => {
const [customInput, setCustomInput] = useState<string>("");
const [maxNewTokens, setMaxNewTokens] = useState<number>(128);
const [topK, setTopK] = useState<number>(50);
const [topP, setTopP] = useState<number>(0.95);
const [logitTopK, setLogitTopK] = useState<number>(5);
const [sample, setSample] = useState<ModelGeneration | null>(null);
const [state, submit] = useAsyncFn(async () => {
if (!customInput) {
alert("Please enter your input.");
return;
}
const sample = await fetch(
`${import.meta.env.VITE_BACKEND_URL}/model/generate?input_text=${encodeURIComponent(
customInput
)}&max_new_tokens=${encodeURIComponent(maxNewTokens.toString())}&top_k=${encodeURIComponent(
topK.toString()
)}&top_p=${encodeURIComponent(topP.toString())}&return_logits_top_k=${encodeURIComponent(logitTopK.toString())}`,
{
method: "POST",
headers: {
Accept: "application/x-msgpack",
},
}
)
.then(async (res) => {
if (!res.ok) {
throw new Error(await res.text());
}
return res;
})
.then(async (res) => await res.arrayBuffer())
// eslint-disable-next-line @typescript-eslint/no-explicit-any
.then((res) => decode(new Uint8Array(res)) as any)
.then((res) =>
camelcaseKeys(res, {
deep: true,
stopPaths: ["context", "logits_tokens"],
})
)
.then((res) => ModelGenerationSchema.parse(res));
setSample(sample);
}, [customInput]);

return (
<div className="flex flex-col gap-4">
<p className="font-bold">Generation</p>
<div className="container grid grid-cols-4 justify-center items-center gap-4 px-20">
<span className="font-bold justify-self-end">Max new tokens:</span>
<Input
disabled={state.loading}
className="bg-white"
type="number"
value={maxNewTokens.toString()}
onChange={(e) => setMaxNewTokens(parseInt(e.target.value))}
/>
<span className="font-bold justify-self-end">Top K:</span>
<Input
disabled={state.loading}
className="bg-white"
type="number"
value={topK.toString()}
onChange={(e) => setTopK(parseInt(e.target.value))}
/>
<span className="font-bold justify-self-end">Top P:</span>
<Input
disabled={state.loading}
className="bg-white"
type="number"
value={topP.toString()}
onChange={(e) => setTopP(parseFloat(e.target.value))}
/>
<span className="font-bold justify-self-end">Logit Top K:</span>
<Input
disabled={state.loading}
className="bg-white"
type="number"
value={logitTopK.toString()}
onChange={(e) => setLogitTopK(parseInt(e.target.value))}
/>
</div>
<Textarea
placeholder="Type your custom input here."
value={customInput}
onChange={(e) => setCustomInput(e.target.value)}
/>
<Button onClick={submit} disabled={state.loading}>
Submit
</Button>
{state.error && <p className="text-red-500">{state.error.message}</p>}
{sample && <ModelSampleArea sample={sample} />}
</div>
);
};

export const ModelCard = () => {
return (
<Card className="container">
<CardHeader>
<CardTitle className="flex justify-between items-center text-xl">
<span>Model</span>
</CardTitle>
</CardHeader>
<CardContent>
<div className="flex flex-col gap-4">
<ModelCustomInputArea />
</div>
</CardContent>
</Card>
);
};
5 changes: 5 additions & 0 deletions ui/src/main.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { FeaturesPage } from "@/routes/features/page";
import { RootPage } from "./routes/page";
import { AttentionHeadPage } from "./routes/attn-heads/page";
import { DictionaryPage } from "./routes/dictionaries/page";
import { ModelsPage } from "./routes/models/page";

const router = createBrowserRouter([
{
Expand All @@ -20,6 +21,10 @@ const router = createBrowserRouter([
path: "/attn-heads",
element: <AttentionHeadPage />,
},
{
path: "/models",
element: <ModelsPage />,
},
{
path: "/",
element: <RootPage />,
Expand Down
15 changes: 15 additions & 0 deletions ui/src/routes/models/page.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import { AppNavbar } from "@/components/app/navbar";
import { ModelCard } from "@/components/model/model-card";

export const ModelsPage = () => {
return (
<div id="Top">
<AppNavbar />
<div className="pt-4 pb-20 px-20 flex flex-col items-center gap-12">
<div className="container flex gap-12">
<ModelCard />
</div>
</div>
</div>
);
};
10 changes: 10 additions & 0 deletions ui/src/types/model.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import { z } from "zod";

export const ModelGenerationSchema = z.object({
context: z.array(z.instanceof(Uint8Array)),
inputMask: z.array(z.number()),
logits: z.array(z.array(z.number())),
logitsTokens: z.array(z.array(z.instanceof(Uint8Array))),
});

export type ModelGeneration = z.infer<typeof ModelGenerationSchema>;

0 comments on commit ad0ec78

Please sign in to comment.