Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the feature of side navigation and preview. #39

Merged
merged 12 commits into from
Jul 29, 2024
60 changes: 60 additions & 0 deletions ui/src/components/app/section-navigator.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import { cn } from "@/lib/utils";
import { Card, CardContent, CardHeader, CardTitle } from "../ui/card";
import { useEffect, useState } from "react";

export const SectionNavigator = ({ sections }: { sections: { title: string; id: string }[] }) => {
const [activeSection, setActiveSection] = useState<{ title: string; id: string } | null>(null);

const handleScroll = () => {
// Use reduce instead of find for obtaining the last section that is in view
const currentSection = sections.reduce((result: { title: string; id: string } | null, section) => {
const secElement = document.getElementById(section.id);
if (!secElement) return result;
const rect = secElement.getBoundingClientRect();
if (rect.top <= window.innerHeight / 2) {
return section;
}
return result;
}, null);

setActiveSection(currentSection);
};

useEffect(() => {
window.addEventListener("scroll", handleScroll);

// Run the handler to set the initial active section
handleScroll();

return () => {
window.removeEventListener("scroll", handleScroll);
};
});

return (
<Card className="py-4 sticky top-0 w-60 h-full bg-transparent">
<CardHeader className="py-0">
<CardTitle className="flex justify-between items-center text-xs p-2">
<span className="font-bold">CONTENTS</span>
</CardTitle>
</CardHeader>
<CardContent className="py-0">
<div className="flex flex-col">
<ul>
{sections.map((section) => (
<li key={section.id} className="relative">
<a
href={"#" + section.id}
className={cn("p-2 block text-neutral-700", activeSection === section && "text-[blue]")}
>
{section.title}
</a>
{activeSection === section && <div className="absolute -left-1.5 top-0 bottom-0 w-0.5 bg-[blue]"></div>}
</li>
))}
</ul>
</div>
</CardContent>
</Card>
);
};
24 changes: 13 additions & 11 deletions ui/src/components/dictionary/sample.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -136,17 +136,19 @@ export const DictionarySampleArea = ({ samples, onSamplesChange, dictionaryName
...featureAct,
}))
)
.reduce((acc, featureAct) => {
// Group by featureActIndex
const key = featureAct.featureActIndex.toString();
if (acc[key]) {
acc[key].push(featureAct);
} else {
acc[key] = [featureAct];
}
return acc;
}, {} as Record<string, { token: Uint8Array; tokenIndex: number; featureAct: number; maxFeatureAct: number }[]>) ||
{}
.reduce(
(acc, featureAct) => {
// Group by featureActIndex
const key = featureAct.featureActIndex.toString();
if (acc[key]) {
acc[key].push(featureAct);
} else {
acc[key] = [featureAct];
}
return acc;
},
{} as Record<string, { token: Uint8Array; tokenIndex: number; featureAct: number; maxFeatureAct: number }[]>
) || {}
)
.sort(
// Sort by sum of featureAct
Expand Down
19 changes: 14 additions & 5 deletions ui/src/components/feature/feature-card.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ export const FeatureCard = ({ feature }: { feature: Feature }) => {
const [showCustomInput, setShowCustomInput] = useState<boolean>(false);

return (
<Card className="container">
<Card id="Interp." className="container">
<CardHeader>
<CardTitle className="flex justify-between items-center text-xl">
<span>
Expand All @@ -108,7 +108,7 @@ export const FeatureCard = ({ feature }: { feature: Feature }) => {

<FeatureInterpretation feature={feature} />

<div className="flex flex-col w-full gap-4">
<div id="Histogram" className="flex flex-col w-full gap-4">
<p className="font-bold">Activation Histogram</p>
<Plot
data={feature.featureActivationHistogram}
Expand All @@ -123,7 +123,7 @@ export const FeatureCard = ({ feature }: { feature: Feature }) => {
</div>

{feature.logits && (
<div className="flex flex-col w-full gap-4">
<div id="Logits" className="flex flex-col w-full gap-4">
<p className="font-bold">Logits</p>
<div className="flex gap-4">
<div className="flex flex-col w-1/2 gap-4">
Expand Down Expand Up @@ -180,15 +180,24 @@ export const FeatureCard = ({ feature }: { feature: Feature }) => {
</div>
)}

<div className="flex flex-col w-full gap-4">
<div id="Activation" className="flex flex-col w-full gap-4">
<Tabs defaultValue="top_activations">
<TabsList className="font-bold">
{feature.sampleGroups.map((sampleGroup) => (
{feature.sampleGroups.slice(0, feature.sampleGroups.length / 2).map((sampleGroup) => (
<TabsTrigger key={`tab-trigger-${sampleGroup.analysisName}`} value={sampleGroup.analysisName}>
{analysisNameMap(sampleGroup.analysisName)}
</TabsTrigger>
))}
</TabsList>
<TabsList className="font-bold">
{feature.sampleGroups
.slice(feature.sampleGroups.length / 2, feature.sampleGroups.length)
.map((sampleGroup) => (
<TabsTrigger key={`tab-trigger-${sampleGroup.analysisName}`} value={sampleGroup.analysisName}>
{analysisNameMap(sampleGroup.analysisName)}
</TabsTrigger>
))}
</TabsList>
{feature.sampleGroups.map((sampleGroup) => (
<TabsContent
key={`tab-content-${sampleGroup.analysisName}`}
Expand Down
77 changes: 64 additions & 13 deletions ui/src/components/feature/sample.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { SuperToken } from "./token";
import { mergeUint8Arrays } from "@/utils/array";
import { useState } from "react";
import { AppPagination } from "../ui/pagination";
import { Accordion, AccordionTrigger, AccordionContent, AccordionItem } from "../ui/accordion";

export const FeatureSampleGroup = ({
feature,
Expand All @@ -12,16 +13,16 @@ export const FeatureSampleGroup = ({
sampleGroup: Feature["sampleGroups"][0];
}) => {
const [page, setPage] = useState<number>(1);
const maxPage = Math.ceil(sampleGroup.samples.length / 5);
const maxPage = Math.ceil(sampleGroup.samples.length / 10);

return (
<div className="flex flex-col gap-4 mt-4">
<p className="font-bold">Max Activation: {Math.max(...sampleGroup.samples[0].featureActs).toFixed(3)}</p>
{sampleGroup.samples.slice((page - 1) * 5, page * 5).map((sample, i) => (
{sampleGroup.samples.slice((page - 1) * 10, page * 10).map((sample, i) => (
<FeatureActivationSample
key={i}
sample={sample}
sampleName={`Sample ${(page - 1) * 5 + i + 1}`}
sampleName={`Sample ${(page - 1) * 10 + i + 1}`}
maxFeatureAct={feature.maxFeatureAct}
/>
))}
Expand Down Expand Up @@ -69,18 +70,68 @@ export const FeatureActivationSample = ({ sample, sampleName, maxFeatureAct }: F
[0]
);

const tokensList = tokens.map((t) => t.featureAct);
const startTrigger = Math.max(tokensList.indexOf(Math.max(...tokensList)) - 100, 0);
const endTrigger = Math.min(tokensList.indexOf(Math.max(...tokensList)) + 10, sample.context.length);
const tokensTrigger = sample.context.slice(startTrigger, endTrigger).map((token, i) => ({
token,
featureAct: sample.featureActs[startTrigger + i],
}));

const [tokenGroupsTrigger, __] = tokensTrigger.reduce<[Token[][], Token[]]>(
([groups, currentGroup], token) => {
const newGroup = [...currentGroup, token];
try {
decoder.decode(mergeUint8Arrays(newGroup.map((t) => t.token)));
return [[...groups, newGroup], []];
} catch {
return [groups, newGroup];
}
},
[[], []]
);

const tokenGroupPositionsTrigger = tokenGroupsTrigger.reduce<number[]>(
(acc, tokenGroup) => {
const tokenCount = tokenGroup.length;
return [...acc, acc[acc.length - 1] + tokenCount];
},
[0]
);

return (
<div>
{sampleName && <span className="text-gray-700 font-bold">{sampleName}: </span>}
{tokenGroups.map((tokens, i) => (
<SuperToken
key={`group-${i}`}
tokens={tokens}
position={tokenGroupPositions[i]}
maxFeatureAct={maxFeatureAct}
sampleMaxFeatureAct={sampleMaxFeatureAct}
/>
))}
<Accordion type="single" collapsible>
<AccordionItem value={sampleMaxFeatureAct.toString()}>
<AccordionTrigger>
<div className="block text-left">
{sampleName && <span className="text-gray-700 font-bold whitespace-pre">{sampleName}: </span>}
{startTrigger != 0 && <span className="text-sky-300">...</span>}
{tokenGroupsTrigger.map((tokens, i) => (
<SuperToken
key={`trigger-group-${i}`}
tokens={tokens}
position={tokenGroupPositionsTrigger[i]}
maxFeatureAct={maxFeatureAct}
sampleMaxFeatureAct={sampleMaxFeatureAct}
/>
))}
{endTrigger != 0 && <span className="text-sky-300"> ...</span>}
</div>
</AccordionTrigger>
<AccordionContent>
{tokenGroups.map((tokens, i) => (
<SuperToken
key={`group-${i}`}
tokens={tokens}
position={tokenGroupPositions[i]}
maxFeatureAct={maxFeatureAct}
sampleMaxFeatureAct={sampleMaxFeatureAct}
/>
))}
</AccordionContent>
</AccordionItem>
</Accordion>
</div>
);
};
2 changes: 1 addition & 1 deletion ui/src/components/ui/accordion.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ const AccordionTrigger = React.forwardRef<
<AccordionPrimitive.Trigger
ref={ref}
className={cn(
"flex flex-1 items-center justify-between py-4 font-medium transition-all hover:underline [&[data-state=open]>svg]:rotate-180",
"flex flex-1 items-center justify-between py-4 font-medium transition-all [&[data-state=open]>svg]:rotate-180",
className
)}
{...props}
Expand Down
4 changes: 4 additions & 0 deletions ui/src/globals.css
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,7 @@
@apply bg-background text-foreground;
}
}

html {
scroll-behavior: smooth;
}
26 changes: 24 additions & 2 deletions ui/src/routes/features/page.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { AppNavbar } from "@/components/app/navbar";
import { FeatureCard } from "@/components/feature/feature-card";
import { SectionNavigator } from "@/components/app/section-navigator";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select";
Expand Down Expand Up @@ -87,8 +88,23 @@ export const FeaturesPage = () => {
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [dictionariesState.value]);

const sections = [
{
title: "Histogram",
id: "Histogram",
},
{
title: "Logits",
id: "Logits",
},
{
title: "Top Activation",
id: "Activation",
},
].filter((section) => (featureState.value && featureState.value.logits != null) || section.id !== "Logits");

return (
<div>
<div id="Top">
<AppNavbar />
<div className="pt-4 pb-20 px-20 flex flex-col items-center gap-12">
<div className="container grid grid-cols-[auto_600px_auto_auto] justify-center items-center gap-4">
Expand Down Expand Up @@ -142,14 +158,20 @@ export const FeaturesPage = () => {
Show Random Feature
</Button>
</div>

{featureState.loading && !loadingRandomFeature && (
<div>
Loading Feature <span className="font-bold">#{featureIndex}</span>...
</div>
)}
{featureState.loading && loadingRandomFeature && <div>Loading Random Living Feature...</div>}
{featureState.error && <div className="text-red-500 font-bold">Error: {featureState.error.message}</div>}
{!featureState.loading && featureState.value && <FeatureCard feature={featureState.value} />}
{!featureState.loading && featureState.value && (
<div className="flex gap-12">
<FeatureCard feature={featureState.value} />
<SectionNavigator sections={sections} />
</div>
)}
</div>
</div>
);
Expand Down