Skip to content

Commit

Permalink
Merge pull request #39 from OpenMOSS/ui
Browse files Browse the repository at this point in the history
Add the feature of side navigation and preview.
  • Loading branch information
dest1n1s authored Jul 29, 2024
2 parents fb6d687 + 21f80b7 commit 98e74ba
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 32 deletions.
60 changes: 60 additions & 0 deletions ui/src/components/app/section-navigator.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import { cn } from "@/lib/utils";
import { Card, CardContent, CardHeader, CardTitle } from "../ui/card";
import { useEffect, useState } from "react";

export const SectionNavigator = ({ sections }: { sections: { title: string; id: string }[] }) => {
const [activeSection, setActiveSection] = useState<{ title: string; id: string } | null>(null);

const handleScroll = () => {
// Use reduce instead of find for obtaining the last section that is in view
const currentSection = sections.reduce((result: { title: string; id: string } | null, section) => {
const secElement = document.getElementById(section.id);
if (!secElement) return result;
const rect = secElement.getBoundingClientRect();
if (rect.top <= window.innerHeight / 2) {
return section;
}
return result;
}, null);

setActiveSection(currentSection);
};

useEffect(() => {
window.addEventListener("scroll", handleScroll);

// Run the handler to set the initial active section
handleScroll();

return () => {
window.removeEventListener("scroll", handleScroll);
};
});

return (
<Card className="py-4 sticky top-0 w-60 h-full bg-transparent">
<CardHeader className="py-0">
<CardTitle className="flex justify-between items-center text-xs p-2">
<span className="font-bold">CONTENTS</span>
</CardTitle>
</CardHeader>
<CardContent className="py-0">
<div className="flex flex-col">
<ul>
{sections.map((section) => (
<li key={section.id} className="relative">
<a
href={"#" + section.id}
className={cn("p-2 block text-neutral-700", activeSection === section && "text-[blue]")}
>
{section.title}
</a>
{activeSection === section && <div className="absolute -left-1.5 top-0 bottom-0 w-0.5 bg-[blue]"></div>}
</li>
))}
</ul>
</div>
</CardContent>
</Card>
);
};
24 changes: 13 additions & 11 deletions ui/src/components/dictionary/sample.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -136,17 +136,19 @@ export const DictionarySampleArea = ({ samples, onSamplesChange, dictionaryName
...featureAct,
}))
)
.reduce((acc, featureAct) => {
// Group by featureActIndex
const key = featureAct.featureActIndex.toString();
if (acc[key]) {
acc[key].push(featureAct);
} else {
acc[key] = [featureAct];
}
return acc;
}, {} as Record<string, { token: Uint8Array; tokenIndex: number; featureAct: number; maxFeatureAct: number }[]>) ||
{}
.reduce(
(acc, featureAct) => {
// Group by featureActIndex
const key = featureAct.featureActIndex.toString();
if (acc[key]) {
acc[key].push(featureAct);
} else {
acc[key] = [featureAct];
}
return acc;
},
{} as Record<string, { token: Uint8Array; tokenIndex: number; featureAct: number; maxFeatureAct: number }[]>
) || {}
)
.sort(
// Sort by sum of featureAct
Expand Down
19 changes: 14 additions & 5 deletions ui/src/components/feature/feature-card.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ export const FeatureCard = ({ feature }: { feature: Feature }) => {
const [showCustomInput, setShowCustomInput] = useState<boolean>(false);

return (
<Card className="container">
<Card id="Interp." className="container">
<CardHeader>
<CardTitle className="flex justify-between items-center text-xl">
<span>
Expand All @@ -108,7 +108,7 @@ export const FeatureCard = ({ feature }: { feature: Feature }) => {

<FeatureInterpretation feature={feature} />

<div className="flex flex-col w-full gap-4">
<div id="Histogram" className="flex flex-col w-full gap-4">
<p className="font-bold">Activation Histogram</p>
<Plot
data={feature.featureActivationHistogram}
Expand All @@ -123,7 +123,7 @@ export const FeatureCard = ({ feature }: { feature: Feature }) => {
</div>

{feature.logits && (
<div className="flex flex-col w-full gap-4">
<div id="Logits" className="flex flex-col w-full gap-4">
<p className="font-bold">Logits</p>
<div className="flex gap-4">
<div className="flex flex-col w-1/2 gap-4">
Expand Down Expand Up @@ -180,15 +180,24 @@ export const FeatureCard = ({ feature }: { feature: Feature }) => {
</div>
)}

<div className="flex flex-col w-full gap-4">
<div id="Activation" className="flex flex-col w-full gap-4">
<Tabs defaultValue="top_activations">
<TabsList className="font-bold">
{feature.sampleGroups.map((sampleGroup) => (
{feature.sampleGroups.slice(0, feature.sampleGroups.length / 2).map((sampleGroup) => (
<TabsTrigger key={`tab-trigger-${sampleGroup.analysisName}`} value={sampleGroup.analysisName}>
{analysisNameMap(sampleGroup.analysisName)}
</TabsTrigger>
))}
</TabsList>
<TabsList className="font-bold">
{feature.sampleGroups
.slice(feature.sampleGroups.length / 2, feature.sampleGroups.length)
.map((sampleGroup) => (
<TabsTrigger key={`tab-trigger-${sampleGroup.analysisName}`} value={sampleGroup.analysisName}>
{analysisNameMap(sampleGroup.analysisName)}
</TabsTrigger>
))}
</TabsList>
{feature.sampleGroups.map((sampleGroup) => (
<TabsContent
key={`tab-content-${sampleGroup.analysisName}`}
Expand Down
77 changes: 64 additions & 13 deletions ui/src/components/feature/sample.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { SuperToken } from "./token";
import { mergeUint8Arrays } from "@/utils/array";
import { useState } from "react";
import { AppPagination } from "../ui/pagination";
import { Accordion, AccordionTrigger, AccordionContent, AccordionItem } from "../ui/accordion";

export const FeatureSampleGroup = ({
feature,
Expand All @@ -12,16 +13,16 @@ export const FeatureSampleGroup = ({
sampleGroup: Feature["sampleGroups"][0];
}) => {
const [page, setPage] = useState<number>(1);
const maxPage = Math.ceil(sampleGroup.samples.length / 5);
const maxPage = Math.ceil(sampleGroup.samples.length / 10);

return (
<div className="flex flex-col gap-4 mt-4">
<p className="font-bold">Max Activation: {Math.max(...sampleGroup.samples[0].featureActs).toFixed(3)}</p>
{sampleGroup.samples.slice((page - 1) * 5, page * 5).map((sample, i) => (
{sampleGroup.samples.slice((page - 1) * 10, page * 10).map((sample, i) => (
<FeatureActivationSample
key={i}
sample={sample}
sampleName={`Sample ${(page - 1) * 5 + i + 1}`}
sampleName={`Sample ${(page - 1) * 10 + i + 1}`}
maxFeatureAct={feature.maxFeatureAct}
/>
))}
Expand Down Expand Up @@ -69,18 +70,68 @@ export const FeatureActivationSample = ({ sample, sampleName, maxFeatureAct }: F
[0]
);

const tokensList = tokens.map((t) => t.featureAct);
const startTrigger = Math.max(tokensList.indexOf(Math.max(...tokensList)) - 100, 0);
const endTrigger = Math.min(tokensList.indexOf(Math.max(...tokensList)) + 10, sample.context.length);
const tokensTrigger = sample.context.slice(startTrigger, endTrigger).map((token, i) => ({
token,
featureAct: sample.featureActs[startTrigger + i],
}));

const [tokenGroupsTrigger, __] = tokensTrigger.reduce<[Token[][], Token[]]>(
([groups, currentGroup], token) => {
const newGroup = [...currentGroup, token];
try {
decoder.decode(mergeUint8Arrays(newGroup.map((t) => t.token)));
return [[...groups, newGroup], []];
} catch {
return [groups, newGroup];
}
},
[[], []]
);

const tokenGroupPositionsTrigger = tokenGroupsTrigger.reduce<number[]>(
(acc, tokenGroup) => {
const tokenCount = tokenGroup.length;
return [...acc, acc[acc.length - 1] + tokenCount];
},
[0]
);

return (
<div>
{sampleName && <span className="text-gray-700 font-bold">{sampleName}: </span>}
{tokenGroups.map((tokens, i) => (
<SuperToken
key={`group-${i}`}
tokens={tokens}
position={tokenGroupPositions[i]}
maxFeatureAct={maxFeatureAct}
sampleMaxFeatureAct={sampleMaxFeatureAct}
/>
))}
<Accordion type="single" collapsible>
<AccordionItem value={sampleMaxFeatureAct.toString()}>
<AccordionTrigger>
<div className="block text-left">
{sampleName && <span className="text-gray-700 font-bold whitespace-pre">{sampleName}: </span>}
{startTrigger != 0 && <span className="text-sky-300">...</span>}
{tokenGroupsTrigger.map((tokens, i) => (
<SuperToken
key={`trigger-group-${i}`}
tokens={tokens}
position={tokenGroupPositionsTrigger[i]}
maxFeatureAct={maxFeatureAct}
sampleMaxFeatureAct={sampleMaxFeatureAct}
/>
))}
{endTrigger != 0 && <span className="text-sky-300"> ...</span>}
</div>
</AccordionTrigger>
<AccordionContent>
{tokenGroups.map((tokens, i) => (
<SuperToken
key={`group-${i}`}
tokens={tokens}
position={tokenGroupPositions[i]}
maxFeatureAct={maxFeatureAct}
sampleMaxFeatureAct={sampleMaxFeatureAct}
/>
))}
</AccordionContent>
</AccordionItem>
</Accordion>
</div>
);
};
2 changes: 1 addition & 1 deletion ui/src/components/ui/accordion.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ const AccordionTrigger = React.forwardRef<
<AccordionPrimitive.Trigger
ref={ref}
className={cn(
"flex flex-1 items-center justify-between py-4 font-medium transition-all hover:underline [&[data-state=open]>svg]:rotate-180",
"flex flex-1 items-center justify-between py-4 font-medium transition-all [&[data-state=open]>svg]:rotate-180",
className
)}
{...props}
Expand Down
4 changes: 4 additions & 0 deletions ui/src/globals.css
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,7 @@
@apply bg-background text-foreground;
}
}

html {
scroll-behavior: smooth;
}
26 changes: 24 additions & 2 deletions ui/src/routes/features/page.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { AppNavbar } from "@/components/app/navbar";
import { FeatureCard } from "@/components/feature/feature-card";
import { SectionNavigator } from "@/components/app/section-navigator";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select";
Expand Down Expand Up @@ -87,8 +88,23 @@ export const FeaturesPage = () => {
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [dictionariesState.value]);

const sections = [
{
title: "Histogram",
id: "Histogram",
},
{
title: "Logits",
id: "Logits",
},
{
title: "Top Activation",
id: "Activation",
},
].filter((section) => (featureState.value && featureState.value.logits != null) || section.id !== "Logits");

return (
<div>
<div id="Top">
<AppNavbar />
<div className="pt-4 pb-20 px-20 flex flex-col items-center gap-12">
<div className="container grid grid-cols-[auto_600px_auto_auto] justify-center items-center gap-4">
Expand Down Expand Up @@ -142,14 +158,20 @@ export const FeaturesPage = () => {
Show Random Feature
</Button>
</div>

{featureState.loading && !loadingRandomFeature && (
<div>
Loading Feature <span className="font-bold">#{featureIndex}</span>...
</div>
)}
{featureState.loading && loadingRandomFeature && <div>Loading Random Living Feature...</div>}
{featureState.error && <div className="text-red-500 font-bold">Error: {featureState.error.message}</div>}
{!featureState.loading && featureState.value && <FeatureCard feature={featureState.value} />}
{!featureState.loading && featureState.value && (
<div className="flex gap-12">
<FeatureCard feature={featureState.value} />
<SectionNavigator sections={sections} />
</div>
)}
</div>
</div>
);
Expand Down

0 comments on commit 98e74ba

Please sign in to comment.