diff --git a/server/app.py b/server/app.py
index cb29624..577afc5 100644
--- a/server/app.py
+++ b/server/app.py
@@ -1,9 +1,11 @@
 import os
 
-from typing import Any, cast
+from typing import Annotated, Any, Literal, cast
 
 import numpy as np
+from pydantic import BaseModel
 import torch
+from transformer_lens.hook_points import HookPoint
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformer_lens import HookedTransformer, HookedTransformerConfig
@@ -13,7 +15,7 @@
 import msgpack
 
-from fastapi import FastAPI, Response
+from fastapi import FastAPI, Query, Response
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
@@ -21,6 +23,7 @@
 import plotly.graph_objects as go
 
 from lm_saes.analysis.auto_interp import check_description, generate_description
+from lm_saes.circuit.context import apply_sae
 from lm_saes.config import AutoInterpConfig, LanguageModelConfig, SAEConfig
 from lm_saes.database import MongoClient
 from lm_saes.sae import SparseAutoEncoder
@@ -300,13 +303,102 @@ def dictionary_custom_input(dictionary_name: str, input_text: str):
     return Response(content=msgpack.packb(sample), media_type="application/x-msgpack")
 
 
+class SteeringConfig(BaseModel):
+    sae: str
+    feature_index: int
+    steering_type: Literal["times", "add", "set", "ablate"]
+    steering_value: float | None = None
+
+class ModelGenerateRequest(BaseModel):
+    input_text: str
+    max_new_tokens: int = 128
+    top_k: int = 50
+    top_p: float = 0.95
+    return_logits_top_k: int = 5
+    saes: list[str] = []
+    steerings: list[SteeringConfig] = []
+
+@app.post("/model/generate")
+def model_generate(request: ModelGenerateRequest):
+    dictionaries = client.list_dictionaries(dictionary_series=dictionary_series)
+    assert len(dictionaries) > 0, "No dictionaries found. Model name cannot be inferred."
+    model = get_model(dictionaries[0])
+    saes = [(get_sae(name), name) for name in request.saes]
+    max_feature_acts = {
+        name: client.get_max_feature_acts(name, dictionary_series=dictionary_series)
+        for _, name in saes
+    }
+    assert all(steering.sae in request.saes for steering in request.steerings), "Steering SAE not found"
+
+    def generate_steering_hook(steering: SteeringConfig):
+        def steering_hook(tensor: torch.Tensor, hook: HookPoint):
+            assert len(tensor.shape) == 3
+            tensor = tensor.clone()
+            if steering.steering_type == "times":
+                assert steering.steering_value is not None
+                tensor[:, :, steering.feature_index] *= steering.steering_value
+            elif steering.steering_type == "ablate":
+                tensor[:, :, steering.feature_index] = 0
+            elif steering.steering_type == "add":
+                assert steering.steering_value is not None
+                tensor[:, :, steering.feature_index] += steering.steering_value
+            elif steering.steering_type == "set":
+                assert steering.steering_value is not None
+                tensor[:, :, steering.feature_index] = steering.steering_value
+            return tensor
+        sae = get_sae(steering.sae)
+        return f"{sae.cfg.hook_point_out}.sae.hook_feature_acts", steering_hook
+
+    steerings_hooks = [generate_steering_hook(steering) for steering in request.steerings]
+
+    with torch.no_grad():
+        with apply_sae(model, [sae for sae, _ in saes]):
+            with model.hooks(steerings_hooks):
+                input = model.to_tokens(request.input_text, prepend_bos=False)
+                output = model.generate(input, max_new_tokens=request.max_new_tokens, top_k=request.top_k, top_p=request.top_p)
+                output = output.clone()
+                logits, cache = model.run_with_cache(output, names_filter=[f"{sae.cfg.hook_point_out}.sae.hook_feature_acts" for sae, _ in saes])
+                logits_topk = [torch.topk(l, request.return_logits_top_k) for l in logits[0]]
+                result = {
+                    "context": [
+                        bytearray([byte_decoder[c] for c in t])
+                        for t in model.tokenizer.convert_ids_to_tokens(output[0])
+                    ],
+                    "logits": [l.values.cpu().numpy().tolist() for l in logits_topk],
+                    "logits_tokens": [
+                        [
+                            bytearray([byte_decoder[c] for c in t])
+                            for t in model.tokenizer.convert_ids_to_tokens(l.indices)
+                        ] for l in logits_topk
+                    ],
+                    "input_mask": [1 for _ in range(len(input[0]))] + [0 for _ in range(len(output[0]) - len(input[0]))],
+                    "sae_info": [
+                        {
+                            "name": name,
+                            "feature_acts_indices": [
+                                cache[f"{sae.cfg.hook_point_out}.sae.hook_feature_acts"][0][i].nonzero(as_tuple=True)[0].cpu().numpy().tolist()
+                                for i in range(cache[f"{sae.cfg.hook_point_out}.sae.hook_feature_acts"][0].shape[0])
+                            ],
+                            "feature_acts": [
+                                cache[f"{sae.cfg.hook_point_out}.sae.hook_feature_acts"][0][i][cache[f"{sae.cfg.hook_point_out}.sae.hook_feature_acts"][0][i].nonzero(as_tuple=True)[0]].cpu().numpy().tolist()
+                                for i in range(cache[f"{sae.cfg.hook_point_out}.sae.hook_feature_acts"][0].shape[0])
+                            ],
+                            "max_feature_acts": [
+                                [max_feature_acts[name][j] for j in cache[f"{sae.cfg.hook_point_out}.sae.hook_feature_acts"][0][i].nonzero(as_tuple=True)[0].cpu().numpy().tolist()]
+                                for i in range(cache[f"{sae.cfg.hook_point_out}.sae.hook_feature_acts"][0].shape[0])
+                            ]
+                        } for sae, name in saes
+                    ],
+                }
+    return Response(content=msgpack.packb(result), media_type="application/x-msgpack")
+
 @app.post("/dictionaries/{dictionary_name}/features/{feature_index}/interpret")
 def feature_interpretation(
-    dictionary_name: str,
-    feature_index: int,
-    type: str,
-    custom_interpretation: str | None = None,
+    dictionary_name: str,
+    feature_index: int,
+    type: str,
+    custom_interpretation: str | None = None,
 ):
     model = get_model(dictionary_name)
     if type == "custom":
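For reference, a minimal client-side sketch of how the `/model/generate` endpoint added above could be called. Field names follow `ModelGenerateRequest` and `SteeringConfig`, and the msgpack handling mirrors what the frontend does; the base URL, the `example-sae` dictionary name, and feature index 42 are placeholders rather than values taken from this patch.

# Sketch only: assumes the FastAPI app from server/app.py is reachable at BASE_URL
# and that "example-sae" / feature 42 stand in for a real dictionary and feature.
import msgpack
import requests

BASE_URL = "http://localhost:8000"  # placeholder; point this at the running backend

payload = {
    "input_text": "The quick brown fox",
    "max_new_tokens": 32,
    "top_k": 50,
    "top_p": 0.95,
    "saes": ["example-sae"],  # hypothetical dictionary name
    "steerings": [
        {
            "sae": "example-sae",
            "feature_index": 42,       # hypothetical feature
            "steering_type": "times",  # one of "times", "add", "set", "ablate"
            "steering_value": 4.0,     # ignored for "ablate"
        }
    ],
}

resp = requests.post(
    f"{BASE_URL}/model/generate",
    json=payload,
    headers={"Accept": "application/x-msgpack", "Content-Type": "application/json"},
)
resp.raise_for_status()

# The endpoint returns a msgpack-encoded dict; each entry of "context" is the raw bytes of one token.
result = msgpack.unpackb(resp.content, raw=False)
text = b"".join(bytes(t) for t in result["context"]).decode("utf-8", errors="replace")
print(text)
print(result["sae_info"][0]["feature_acts_indices"][:3])  # active features at the first three positions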
== "custom": diff --git a/src/lm_saes/config.py b/src/lm_saes/config.py index 606d695..62fdded 100644 --- a/src/lm_saes/config.py +++ b/src/lm_saes/config.py @@ -105,7 +105,7 @@ class TextDatasetConfig(RunnerConfig): context_size: int = 128 store_batch_size: int = 64 sample_probs: List[float] = field(default_factory=lambda: [1.0]) - prepend_bos: List[bool] = field(default_factory=lambda: [False]) + prepend_bos: List[bool] = field(default_factory=lambda: [True]) def __post_init__(self): super().__post_init__() @@ -118,6 +118,9 @@ def __post_init__(self): if isinstance(self.prepend_bos, bool): self.prepend_bos = [self.prepend_bos] + if False in self.prepend_bos: + print('Warning: prepend_bos is set to False for some datasets. This setting might not be suitable for most modern models.') + self.sample_probs = [p / sum(self.sample_probs) for p in self.sample_probs] assert len(self.sample_probs) == len( @@ -393,6 +396,8 @@ def __post_init__(self): assert 0 <= self.lr_cool_down_steps <= 1.0 self.lr_cool_down_steps = int(self.lr_cool_down_steps * total_training_steps) print_once(f"Learning rate cool down steps: {self.lr_cool_down_steps}") + if self.finetuning: + assert self.sae.l1_coefficient == 0.0, "L1 coefficient must be 0.0 for finetuning." @dataclass(kw_only=True) class LanguageModelSAEPruningConfig(LanguageModelSAERunnerConfig): diff --git a/src/lm_saes/runner.py b/src/lm_saes/runner.py index 088099b..5095d84 100644 --- a/src/lm_saes/runner.py +++ b/src/lm_saes/runner.py @@ -48,15 +48,6 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig): - if is_master(): - cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_path)) - cfg.lm.save_lm_config(os.path.join(cfg.exp_result_path)) - sae = SparseAutoEncoder.from_config(cfg=cfg.sae) - - if cfg.finetuning: - # Fine-tune SAE with frozen encoder weights and bias - sae.train_finetune_for_suppression_parameters() - hf_model = AutoModelForCausalLM.from_pretrained( ( cfg.lm.model_name @@ -87,12 +78,29 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig): tokenizer=hf_tokenizer, dtype=cfg.lm.dtype, ) - model.offload_params_after( - cfg.act_store.hook_points[0], torch.tensor([[0]], device=cfg.lm.device) - ) + model.offload_params_after(cfg.act_store.hook_points[-1], torch.tensor([[0]], device=cfg.lm.device)) model.eval() activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store) + if not cfg.finetuning and ( + cfg.sae.norm_activation == "dataset-wise" and cfg.sae.dataset_average_activation_norm is None + or cfg.sae.init_decoder_norm is None + ): + sae = SparseAutoEncoder.from_initialization_searching( + activation_store=activation_store, + cfg=cfg, + ) + else: + sae = SparseAutoEncoder.from_config(cfg=cfg.sae) + + if cfg.finetuning: + # Fine-tune SAE with frozen encoder weights and bias + sae.train_finetune_for_suppression_parameters() + + if is_master(): + cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name)) + cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name)) + if cfg.wandb.log_to_wandb and is_master(): wandb_config: dict = { **asdict(cfg), diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py index 1630057..1f011ab 100644 --- a/src/lm_saes/sae.py +++ b/src/lm_saes/sae.py @@ -140,12 +140,10 @@ def train_base_parameters(self): p.requires_grad_(True) def train_finetune_for_suppression_parameters(self): - """Set the parameters to be trained for feature suppression.""" + """Set the parameters to be trained against feature suppression.""" + + 
diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py
index 1630057..1f011ab 100644
--- a/src/lm_saes/sae.py
+++ b/src/lm_saes/sae.py
@@ -140,12 +140,10 @@ def train_base_parameters(self):
             p.requires_grad_(True)
 
     def train_finetune_for_suppression_parameters(self):
-        """Set the parameters to be trained for feature suppression."""
+        """Set the parameters to be trained against feature suppression."""
+
+        finetune_for_suppression_parameters = [self.decoder.weight]
-        finetune_for_suppression_parameters = [
-            self.feature_act_scale,
-            self.decoder.weight,
-        ]
         if self.cfg.use_decoder_bias:
             finetune_for_suppression_parameters.append(self.decoder.bias)
         for p in self.parameters():
diff --git a/src/lm_saes/utils/huggingface.py b/src/lm_saes/utils/huggingface.py
index d51bae0..7807d4f 100644
--- a/src/lm_saes/utils/huggingface.py
+++ b/src/lm_saes/utils/huggingface.py
@@ -4,6 +4,7 @@
 import os
 import shutil
 from huggingface_hub import create_repo, upload_folder, snapshot_download
+from lm_saes.utils.misc import print_once
 
 
 def upload_pretrained_sae_to_hf(sae_path: str, repo_id: str, private: bool = False):
@@ -54,6 +55,7 @@ def parse_pretrained_name_or_path(pretrained_name_or_path: str):
     if os.path.exists(pretrained_name_or_path):
         return pretrained_name_or_path
     else:
+        print_once(f'Local path `{pretrained_name_or_path}` not found. Downloading from huggingface model hub.')
         repo_id = "/".join(pretrained_name_or_path.split("/")[:2])
         hook_point = "/".join(pretrained_name_or_path.split("/")[2:])
         return download_pretrained_sae_from_hf(repo_id, hook_point)
\ No newline at end of file
diff --git a/ui/bun.lockb b/ui/bun.lockb
index 6ee4bfe..3bd98f1 100755
Binary files a/ui/bun.lockb and b/ui/bun.lockb differ
diff --git a/ui/package.json b/ui/package.json
index 93486e7..717bf84 100644
--- a/ui/package.json
+++ b/ui/package.json
@@ -12,18 +12,23 @@
   "dependencies": {
     "@msgpack/msgpack": "^3.0.0-beta2",
     "@radix-ui/react-accordion": "^1.1.2",
+    "@radix-ui/react-dialog": "^1.1.1",
     "@radix-ui/react-dropdown-menu": "^2.0.6",
     "@radix-ui/react-hover-card": "^1.0.7",
     "@radix-ui/react-label": "^2.0.2",
+    "@radix-ui/react-popover": "^1.1.1",
     "@radix-ui/react-select": "^2.0.0",
     "@radix-ui/react-separator": "^1.0.3",
     "@radix-ui/react-slot": "^1.0.2",
+    "@radix-ui/react-switch": "^1.1.0",
     "@radix-ui/react-tabs": "^1.0.4",
     "@radix-ui/react-toggle": "^1.0.3",
+    "@radix-ui/react-tooltip": "^1.1.2",
     "@tanstack/react-table": "^8.15.3",
     "camelcase-keys": "^9.1.3",
     "class-variance-authority": "^0.7.0",
     "clsx": "^2.1.0",
+    "cmdk": "1.0.0",
     "lucide-react": "^0.358.0",
     "plotly.js": "^2.30.1",
     "react": "^18.2.0",
@@ -31,6 +36,8 @@
     "react-plotly.js": "^2.6.0",
     "react-router-dom": "^6.22.3",
     "react-use": "^17.5.0",
+    "recharts": "^2.12.7",
+    "snakecase-keys": "^8.0.1",
     "tailwind-merge": "^2.2.1",
     "tailwindcss-animate": "^1.0.7",
     "zod": "^3.22.4",
diff --git a/ui/src/components/app/navbar.tsx b/ui/src/components/app/navbar.tsx
index 34cf336..6279562 100644
--- a/ui/src/components/app/navbar.tsx
+++ b/ui/src/components/app/navbar.tsx
@@ -28,6 +28,15 @@ export const AppNavbar = () => {
         >
           Dictionaries
         
+        
+          Models
+        
diff --git a/ui/src/components/app/sample.tsx b/ui/src/components/app/sample.tsx
index e96b40b..5d90ff1 100644
--- a/ui/src/components/app/sample.tsx
+++ b/ui/src/components/app/sample.tsx
@@ -1,61 +1,55 @@
+import { useState } from "react";
+import { TokenGroup } from "./token";
 import { cn } from "@/lib/utils";
-import { mergeUint8Arrays } from "@/utils/array";
 
-export type SimpleSampleAreaProps = {
-  sample: { context: Uint8Array[] };
-  sampleName: string;
-  tokenGroupClassName?: (tokens: { token: Uint8Array }[], i: number) => string;
-  tokenGroupProps?: (tokens: { token: Uint8Array }[], i: number) => React.HTMLProps;
+export type SampleProps = {
+  sampleName?: string;
+  tokenGroups: T[][];
+  tokenGroupClassName?: (tokenGroup: T[], i: number) => string;
+  tokenGroupProps?: (tokenGroup: T[], i: number) => React.HTMLProps;
+
tokenInfoContent?: (tokenGroup: T[], i: number) => (token: T, i: number) => React.ReactNode; + tokenGroupInfoContent?: (tokenGroup: T[], i: number) => React.ReactNode; + customTokenGroup?: (tokens: T[], i: number) => React.ReactNode; + foldedStart?: number; }; -export const SimpleSampleArea = ({ - sample, +export const Sample = ({ sampleName, + tokenGroups, tokenGroupClassName, tokenGroupProps, -}: SimpleSampleAreaProps) => { - const decoder = new TextDecoder("utf-8", { fatal: true }); - - const start = Math.max(0); - const end = Math.min(sample.context.length); - const tokens = sample.context.slice(start, end).map((token) => ({ - token, - })); - - type Token = { token: Uint8Array }; - - const [tokenGroups, _] = tokens.reduce<[Token[][], Token[]]>( - ([groups, currentGroup], token) => { - const newGroup = [...currentGroup, token]; - try { - decoder.decode(mergeUint8Arrays(newGroup.map((t) => t.token))); - return [[...groups, newGroup], []]; - } catch { - return [groups, newGroup]; - } - }, - [[], []] - ); + tokenInfoContent, + tokenGroupInfoContent, + customTokenGroup, + foldedStart, +}: SampleProps) => { + const [folded, setFolded] = useState(true); return ( -
- {sampleName && {sampleName}: } - {tokenGroups.map((tokens, i) => ( - +
setFolded(!folded) : undefined} + > + {sampleName && {sampleName}: } + {folded && !!foldedStart && ...} + {tokenGroups + .slice((folded && foldedStart) || 0) + .map((tokens, i) => + customTokenGroup ? ( + customTokenGroup(tokens, i) + ) : ( + + ) )} - key={i} - {...tokenGroupProps?.(tokens, i)} - > - {decoder - .decode(mergeUint8Arrays(tokens.map((t) => t.token))) - .replace("\n", "⏎") - .replace("\t", "⇥") - .replace("\r", "↵")} - - ))} +
); }; diff --git a/ui/src/components/app/token.tsx b/ui/src/components/app/token.tsx new file mode 100644 index 0000000..f6f84b7 --- /dev/null +++ b/ui/src/components/app/token.tsx @@ -0,0 +1,82 @@ +import { cn } from "@/lib/utils"; +import { mergeUint8Arrays } from "@/utils/array"; +import { HoverCard, HoverCardContent, HoverCardTrigger } from "../ui/hover-card"; +import { Fragment } from "react/jsx-runtime"; +import { Separator } from "../ui/separator"; + +export type PlainTokenGroupProps = { + tokenGroup: T[]; + tokenGroupClassName?: string; + tokenGroupProps?: React.HTMLProps; +}; + +export const PlainTokenGroup = ({ + tokenGroup, + tokenGroupClassName, + tokenGroupProps, +}: PlainTokenGroupProps) => { + const decoder = new TextDecoder("utf-8", { fatal: true }); + + return ( + + {decoder + .decode(mergeUint8Arrays(tokenGroup.map((t) => t.token))) + .replace("\n", "⏎") + .replace("\t", "⇥") + .replace("\r", "↵")} + + ); +}; + +export type TokenGroupProps = PlainTokenGroupProps & { + tokenInfoContent?: (token: T, i: number) => React.ReactNode; + tokenGroupInfoContent?: React.ReactNode; +}; + +export const TokenGroup = ({ + tokenGroup, + tokenGroupClassName, + tokenGroupProps, + tokenInfoContent, + tokenGroupInfoContent, +}: TokenGroupProps) => { + if (!tokenInfoContent && !tokenGroupInfoContent) { + return ( + + ); + } + + return ( + + + + + + {tokenGroupInfoContent ? ( + {tokenGroupInfoContent} + ) : ( + tokenGroup.map((token, i) => ( + + {tokenInfoContent?.(token, i)} + {i < tokenGroup.length - 1 && } + + )) + )} + + + ); +}; diff --git a/ui/src/components/dictionary/dictionary-card.tsx b/ui/src/components/dictionary/dictionary-card.tsx index 654e5c9..f9898f2 100644 --- a/ui/src/components/dictionary/dictionary-card.tsx +++ b/ui/src/components/dictionary/dictionary-card.tsx @@ -1,17 +1,17 @@ import { useState } from "react"; import { Button } from "../ui/button"; import { Card, CardContent, CardHeader, CardTitle } from "../ui/card"; -import { Dictionary, DictionarySample, DictionarySampleSchema } from "@/types/dictionary"; +import { Dictionary, DictionarySampleCompact, DictionarySampleCompactSchema } from "@/types/dictionary"; import Plot from "react-plotly.js"; import { useAsyncFn } from "react-use"; import { decode } from "@msgpack/msgpack"; import camelcaseKeys from "camelcase-keys"; import { Textarea } from "../ui/textarea"; -import { DictionarySampleArea } from "./sample"; +import { DictionarySample } from "./sample"; const DictionaryCustomInputArea = ({ dictionary }: { dictionary: Dictionary }) => { const [customInput, setCustomInput] = useState(""); - const [samples, setSamples] = useState([]); + const [samples, setSamples] = useState([]); const [state, submit] = useAsyncFn(async () => { if (!customInput) { alert("Please enter your input."); @@ -43,7 +43,7 @@ const DictionaryCustomInputArea = ({ dictionary }: { dictionary: Dictionary }) = stopPaths: ["context"], }) ) - .then((res) => DictionarySampleSchema.parse(res)); + .then((res) => DictionarySampleCompactSchema.parse(res)); setSamples((prev) => [...prev, sample]); }, [customInput]); @@ -60,11 +60,7 @@ const DictionaryCustomInputArea = ({ dictionary }: { dictionary: Dictionary }) = {state.error &&

{state.error.message}

} {samples.length > 0 && ( - + )} ); diff --git a/ui/src/components/dictionary/sample.tsx b/ui/src/components/dictionary/sample.tsx index a5f5f23..f04f4d1 100644 --- a/ui/src/components/dictionary/sample.tsx +++ b/ui/src/components/dictionary/sample.tsx @@ -1,23 +1,24 @@ import { cn } from "@/lib/utils"; -import { DictionarySample, DictionaryToken } from "@/types/dictionary"; -import { mergeUint8Arrays, zip } from "@/utils/array"; +import { DictionarySampleCompact } from "@/types/dictionary"; +import { zip } from "@/utils/array"; import { useState } from "react"; import { ColumnDef } from "@tanstack/react-table"; import { DataTable } from "../ui/data-table"; import { getAccentClassname } from "@/utils/style"; -import { SimpleSampleArea } from "../app/sample"; +import { Sample } from "../app/sample"; import { HoverCard, HoverCardContent } from "../ui/hover-card"; import { HoverCardTrigger } from "@radix-ui/react-hover-card"; import { FeatureLinkWithPreview } from "../app/feature-preview"; import { Trash2 } from "lucide-react"; +import { countTokenGroupPositions, groupToken, hex } from "@/utils/token"; -export type DictionarySampleAreaProps = { - samples: DictionarySample[]; - onSamplesChange?: (samples: DictionarySample[]) => void; +export type DictionarySampleProps = { + samples: DictionarySampleCompact[]; + onSamplesChange?: (samples: DictionarySampleCompact[]) => void; dictionaryName: string; }; -export const DictionarySampleArea = ({ samples, onSamplesChange, dictionaryName }: DictionarySampleAreaProps) => { +export const DictionarySample = ({ samples, onSamplesChange, dictionaryName }: DictionarySampleProps) => { const [selectedTokenGroupIndices, setSelectedTokenGroupIndices] = useState<[number, number][]>([]); const toggleSelectedTokenGroupIndex = (sampleIndex: number, tokenGroupIndex: number) => { setSelectedTokenGroupIndices((prev) => @@ -29,47 +30,24 @@ export const DictionarySampleArea = ({ samples, onSamplesChange, dictionaryName ); }; - const decoder = new TextDecoder("utf-8", { fatal: true }); - const tokens = samples.map((sample) => - sample.context.map((token, i) => ({ - token, - featureActs: zip(sample.featureActsIndices[i], sample.featureActs[i], sample.maxFeatureActs[i]).map( - ([featureActIndex, featureAct, maxFeatureAct]) => ({ - featureActIndex, - featureAct, - maxFeatureAct, - }) - ), - })) + zip(sample.context, sample.featureActsIndices, sample.featureActs, sample.maxFeatureActs).map( + ([token, featureActsIndices, featureActs, maxFeatureActs]) => ({ + token, + featureActs: zip(featureActsIndices, featureActs, maxFeatureActs).map( + ([featureActIndex, featureAct, maxFeatureAct]) => ({ + featureActIndex, + featureAct, + maxFeatureAct, + }) + ), + }) + ) ); - const tokenGroups = tokens - .map((t) => - t.reduce<[DictionaryToken[][], DictionaryToken[]]>( - ([groups, currentGroup], token) => { - const newGroup = [...currentGroup, token]; - try { - decoder.decode(mergeUint8Arrays(newGroup.map((t) => t.token))); - return [[...groups, newGroup], []]; - } catch { - return [groups, newGroup]; - } - }, - [[], []] - ) - ) - .map((v) => v[0]); + const tokenGroups = tokens.map(groupToken); - const tokenGroupPositions = tokenGroups.map((tokenGroupRow) => - tokenGroupRow.reduce( - (acc, tokenGroup) => { - const tokenCount = tokenGroup.length; - return [...acc, acc[acc.length - 1] + tokenCount]; - }, - [0] - ) - ); + const tokenGroupPositions = tokenGroups.map(countTokenGroupPositions); const selectedTokenGroups = selectedTokenGroupIndices.map(([s, t]) => tokenGroups[s][t]); const 
selectedTokens = selectedTokenGroups.flatMap((tokens) => tokens); @@ -96,19 +74,13 @@ export const DictionarySampleArea = ({ samples, onSamplesChange, dictionaryName accessorKey: `token${i}`, header: () => ( - - {token.token.reduce( - (acc, b) => - b < 32 || b > 126 ? `${acc}\\x${b.toString(16).padStart(2, "0")}` : `${acc}${String.fromCharCode(b)}`, - "" - )} - + {hex(token)}
Position: {tokenGroupPositions[s][t] + inGroupIndex}
- (j === t ? "bg-orange-500" : "")} /> @@ -136,19 +108,17 @@ export const DictionarySampleArea = ({ samples, onSamplesChange, dictionaryName ...featureAct, })) ) - .reduce( - (acc, featureAct) => { - // Group by featureActIndex - const key = featureAct.featureActIndex.toString(); - if (acc[key]) { - acc[key].push(featureAct); - } else { - acc[key] = [featureAct]; - } - return acc; - }, - {} as Record - ) || {} + .reduce((acc, featureAct) => { + // Group by featureActIndex + const key = featureAct.featureActIndex.toString(); + if (acc[key]) { + acc[key].push(featureAct); + } else { + acc[key] = [featureAct]; + } + return acc; + }, {} as Record) || + {} ) .sort( // Sort by sum of featureAct @@ -173,10 +143,10 @@ export const DictionarySampleArea = ({ samples, onSamplesChange, dictionaryName return (
- {samples.map((sample, sampleIndex) => ( + {tokenGroups.map((tokenGroups, sampleIndex) => (
- cn( diff --git a/ui/src/components/feature/feature-card.tsx b/ui/src/components/feature/feature-card.tsx index fa41369..7df9730 100644 --- a/ui/src/components/feature/feature-card.tsx +++ b/ui/src/components/feature/feature-card.tsx @@ -1,4 +1,4 @@ -import { Feature, SampleSchema } from "@/types/feature"; +import { Feature, FeatureSampleCompactSchema } from "@/types/feature"; import { decode } from "@msgpack/msgpack"; import camelcaseKeys from "camelcase-keys"; import { useState } from "react"; @@ -45,7 +45,7 @@ const FeatureCustomInputArea = ({ feature }: { feature: Feature }) => { stopPaths: ["context"], }) ) - .then((res) => SampleSchema.parse(res)); + .then((res) => FeatureSampleCompactSchema.parse(res)); }, [customInput]); return ( diff --git a/ui/src/components/feature/sample.tsx b/ui/src/components/feature/sample.tsx index 1b03555..ea5f6d7 100644 --- a/ui/src/components/feature/sample.tsx +++ b/ui/src/components/feature/sample.tsx @@ -1,9 +1,11 @@ -import { Feature, Sample, Token } from "@/types/feature"; -import { SuperToken } from "./token"; -import { mergeUint8Arrays } from "@/utils/array"; +import { Feature, FeatureSampleCompact } from "@/types/feature"; import { useState } from "react"; import { AppPagination } from "../ui/pagination"; -import { Accordion, AccordionTrigger, AccordionContent, AccordionItem } from "../ui/accordion"; +import { countTokenGroupPositions, groupToken, hex } from "@/utils/token"; +import { zip } from "@/utils/array"; +import { getAccentClassname } from "@/utils/style"; +import { cn } from "@/lib/utils"; +import { Sample } from "../app/sample"; export const FeatureSampleGroup = ({ feature, @@ -20,7 +22,7 @@ export const FeatureSampleGroup = ({

Max Activation: {Math.max(...sampleGroup.samples[0].featureActs).toFixed(3)}

{sampleGroup.samples.slice((page - 1) * 10, page * 10).map((sample, i) => ( { + return ( +
+
Token:
+
{hex(token)}
+
Position:
+
{position}
+
Activation:
+
+ {token.featureAct.toFixed(3)} +
+
+ ); +}; + export type FeatureActivationSampleProps = { - sample: Sample; + sample: FeatureSampleCompact; sampleName: string; maxFeatureAct: number; }; @@ -40,98 +63,32 @@ export type FeatureActivationSampleProps = { export const FeatureActivationSample = ({ sample, sampleName, maxFeatureAct }: FeatureActivationSampleProps) => { const sampleMaxFeatureAct = Math.max(...sample.featureActs); - const decoder = new TextDecoder("utf-8", { fatal: true }); - - const start = Math.max(0); - const end = Math.min(sample.context.length); - const tokens = sample.context.slice(start, end).map((token, i) => ({ + const tokens = zip(sample.context, sample.featureActs).map(([token, featureAct]) => ({ token, - featureAct: sample.featureActs[start + i], + featureAct, })); - const [tokenGroups, _] = tokens.reduce<[Token[][], Token[]]>( - ([groups, currentGroup], token) => { - const newGroup = [...currentGroup, token]; - try { - decoder.decode(mergeUint8Arrays(newGroup.map((t) => t.token))); - return [[...groups, newGroup], []]; - } catch { - return [groups, newGroup]; - } - }, - [[], []] - ); + const tokenGroups = groupToken(tokens); + const tokenGroupPositions = countTokenGroupPositions(tokenGroups); - const tokenGroupPositions = tokenGroups.reduce( - (acc, tokenGroup) => { - const tokenCount = tokenGroup.length; - return [...acc, acc[acc.length - 1] + tokenCount]; - }, - [0] - ); - - const tokensList = tokens.map((t) => t.featureAct); - const startTrigger = Math.max(tokensList.indexOf(Math.max(...tokensList)) - 100, 0); - const endTrigger = Math.min(tokensList.indexOf(Math.max(...tokensList)) + 10, sample.context.length); - const tokensTrigger = sample.context.slice(startTrigger, endTrigger).map((token, i) => ({ - token, - featureAct: sample.featureActs[startTrigger + i], - })); - - const [tokenGroupsTrigger, __] = tokensTrigger.reduce<[Token[][], Token[]]>( - ([groups, currentGroup], token) => { - const newGroup = [...currentGroup, token]; - try { - decoder.decode(mergeUint8Arrays(newGroup.map((t) => t.token))); - return [[...groups, newGroup], []]; - } catch { - return [groups, newGroup]; - } - }, - [[], []] - ); - - const tokenGroupPositionsTrigger = tokenGroupsTrigger.reduce( - (acc, tokenGroup) => { - const tokenCount = tokenGroup.length; - return [...acc, acc[acc.length - 1] + tokenCount]; - }, - [0] - ); + const featureActs = tokenGroups.map((group) => Math.max(...group.map((token) => token.featureAct))); + const start = Math.max(featureActs.findIndex((act) => act === sampleMaxFeatureAct) - 60, 0); return ( -
- - - -
- {sampleName && {sampleName}: } - {startTrigger != 0 && ...} - {tokenGroupsTrigger.map((tokens, i) => ( - - ))} - {endTrigger != 0 && ...} -
-
- - {tokenGroups.map((tokens, i) => ( - - ))} - -
-
-
+ (token, j) => + } + tokenGroupClassName={(tokenGroup) => { + const tokenGroupMaxFeatureAct = Math.max(...tokenGroup.map((t) => t.featureAct)); + return cn( + tokenGroupMaxFeatureAct > 0 && "hover:underline cursor-pointer", + sampleMaxFeatureAct > 0 && tokenGroupMaxFeatureAct == sampleMaxFeatureAct && "font-bold", + getAccentClassname(tokenGroupMaxFeatureAct, maxFeatureAct, "bg") + ); + }} + foldedStart={start} + /> ); }; diff --git a/ui/src/components/feature/token.tsx b/ui/src/components/feature/token.tsx deleted file mode 100644 index 548c5c7..0000000 --- a/ui/src/components/feature/token.tsx +++ /dev/null @@ -1,89 +0,0 @@ -import { cn } from "@/lib/utils"; -import { Token } from "@/types/feature"; -import { mergeUint8Arrays } from "@/utils/array"; -import { HoverCard, HoverCardContent, HoverCardTrigger } from "../ui/hover-card"; -import { Separator } from "../ui/separator"; -import { Fragment } from "react/jsx-runtime"; -import { getAccentClassname } from "@/utils/style"; - -export type TokenInfoProps = { - token: Token; - maxFeatureAct: number; - position: number; -}; - -export const TokenInfo = ({ token, maxFeatureAct, position }: TokenInfoProps) => { - const hex = token.token.reduce( - (acc, b) => (b < 32 || b > 126 ? `${acc}\\x${b.toString(16).padStart(2, "0")}` : `${acc}${String.fromCharCode(b)}`), - "" - ); - - return ( -
-
Token:
-
{hex}
-
Position:
-
{position}
-
Activation:
-
- {token.featureAct.toFixed(3)} -
-
- ); -}; - -export type SuperTokenProps = { - tokens: Token[]; - position: number; - maxFeatureAct: number; - sampleMaxFeatureAct: number; -}; - -export const SuperToken = ({ tokens, position, maxFeatureAct, sampleMaxFeatureAct }: SuperTokenProps) => { - const decoder = new TextDecoder("utf-8", { fatal: true }); - const displayText = decoder - .decode(mergeUint8Arrays(tokens.map((t) => t.token))) - .replace("\n", "⏎") - .replace("\t", "⇥") - .replace("\r", "↵"); - - const superTokenMaxFeatureAct = Math.max(...tokens.map((t) => t.featureAct)); - - const SuperTokenInner = () => { - return ( - 0 && "hover:shadow-lg hover:text-gray-600 cursor-pointer", - sampleMaxFeatureAct > 0 && superTokenMaxFeatureAct == sampleMaxFeatureAct && "font-bold", - getAccentClassname(superTokenMaxFeatureAct, maxFeatureAct, "bg") - )} - > - {displayText} - - ); - }; - - if (superTokenMaxFeatureAct === 0) { - return ; - } - - return ( - - - - - - {tokens.length > 1 && ( -
This super token is composed of the {tokens.length} tokens below:
- )} - {tokens.map((token, i) => ( - - - {i < tokens.length - 1 && } - - ))} -
-
- ); -}; diff --git a/ui/src/components/model/model-card.tsx b/ui/src/components/model/model-card.tsx new file mode 100644 index 0000000..d89ab70 --- /dev/null +++ b/ui/src/components/model/model-card.tsx @@ -0,0 +1,661 @@ +import { Fragment, useEffect, useMemo, useState } from "react"; +import { Button } from "../ui/button"; +import { Card, CardContent, CardHeader, CardTitle } from "../ui/card"; +import { Textarea } from "../ui/textarea"; +import { useAsyncFn, useMount } from "react-use"; +import camelcaseKeys from "camelcase-keys"; +import snakecaseKeys from "snakecase-keys"; +import { decode } from "@msgpack/msgpack"; +import { ModelGeneration, ModelGenerationSchema } from "@/types/model"; +import { Sample } from "../app/sample"; +import { cn } from "@/lib/utils"; +import { BarChart, Bar, XAxis, YAxis, CartesianGrid, Label, LabelList, ResponsiveContainer } from "recharts"; +import { zip } from "@/utils/array"; +import { Input } from "../ui/input"; +import { countTokenGroupPositions, groupToken, hex } from "@/utils/token"; +import { getAccentClassname } from "@/utils/style"; +import { Separator } from "../ui/separator"; +import MultipleSelector from "../ui/multiple-selector"; +import { z } from "zod"; +import { Combobox } from "../ui/combobox"; +import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "../ui/select"; +import { Ban, MoreHorizontal, Plus, Trash2, Wrench, X } from "lucide-react"; +import { ColumnDef } from "@tanstack/react-table"; +import { FeatureLinkWithPreview } from "../app/feature-preview"; +import { DataTable } from "../ui/data-table"; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuLabel, + DropdownMenuSeparator, + DropdownMenuTrigger, +} from "../ui/dropdown-menu"; +import { useNavigate } from "react-router-dom"; +import { Switch } from "../ui/switch"; +import { Label as SLabel } from "../ui/label"; +import { Toggle } from "../ui/toggle"; + +const SAEInfo = ({ + saeInfo, + saeSettings, + onSteerFeature, + setSAESettings, +}: { + saeInfo: { + name: string; + featureActs: { + featureActIndex: number; + featureAct: number; + maxFeatureAct: number; + }[]; + }; + saeSettings: { sortedBySum: boolean }; + onSteerFeature?: (name: string, featureIndex: number) => void; + setSAESettings: (settings: { sortedBySum: boolean }) => void; +}) => { + const navigate = useNavigate(); + + const columns: ColumnDef<{ + featureActIndex: number; + featureAct: number; + maxFeatureAct: number; + }>[] = [ + { + accessorKey: "featureActIndex", + header: () => ( +
+ Feature +
+ ), + cell: ({ row }) => ( +
+ +
+ ), + }, + { + accessorKey: "featureAct", + header: () => ( +
+ Activation +
+ ), + cell: ({ row }) => ( +
+ {row.original.featureAct.toFixed(3)} +
+ ), + }, + { + id: "actions", + enableHiding: false, + cell: ({ row }) => ( + + + + + + Actions + onSteerFeature?.(saeInfo.name, row.original.featureActIndex)}> + Steer feature + + + { + navigate( + `/features?dictionary=${encodeURIComponent(saeInfo.name)}&featureIndex=${ + row.original.featureActIndex + }` + ); + }} + > + View feature + + + + ), + meta: { + headerClassName: "w-16", + cellClassName: "py-0", + }, + }, + ]; + + return ( +
+
+
Features from {saeInfo.name}:
+ setSAESettings({ ...saeSettings, sortedBySum: pressed })} + > + Sort by sum + +
+ +
+ ); +}; + +const LogitsInfo = ({ + logits, +}: { + logits: { + logits: number; + token: Uint8Array; + }[]; +}) => { + const maxLogits = Math.max(...logits.map((logit) => logit.logits)); + const columns: ColumnDef<{ + logits: number; + token: Uint8Array; + }>[] = [ + { + accessorKey: "token", + header: () => ( +
+ Token +
+ ), + cell: ({ row }) => ( +
+ {hex(row.original)} +
+ ), + }, + { + accessorKey: "logits", + header: () => ( +
+ Logits +
+ ), + cell: ({ row }) => ( +
+ {row.original.logits.toFixed(3)} +
+ ), + }, + ]; + + return ( +
+
Logits:
+ +
+ ); +}; + +const LogitsBarChart = ({ + tokens, +}: { + tokens: { + logits: { + logits: number; + token: Uint8Array; + }[]; + token: Uint8Array; + }[]; +}) => { + const data = tokens.map((token) => + Object.assign( + {}, + ...token.logits.map((logit, j) => ({ + [`logits-${j}`]: logit.logits, + [`logits-token-${j}`]: hex(logit), + })), + { + name: hex(token), + } + ) + ); + + const colors = ["#8884d8", "#82ca9d", "#ffc658", "#ff7300", "#d6d6d6"]; + + return ( + + + + + + + {tokens[0].logits.slice(0, 5).map((_, i) => ( + + + + ))} + + + ); +}; + +const ModelSample = ({ + sample, + onSteerFeature, +}: { + sample: ModelGeneration; + onSteerFeature?: (name: string, featureIndex: number) => void; +}) => { + const [selectedTokenGroupIndices, setSelectedTokenGroupIndices] = useState([]); + const toggleSelectedTokenGroupIndex = (tokenGroupIndex: number) => { + setSelectedTokenGroupIndices((prev) => + prev.includes(tokenGroupIndex) + ? prev.filter((t) => t !== tokenGroupIndex) + : [...prev, tokenGroupIndex].sort((a, b) => a - b) + ); + }; + + const [saeSettings, setSAESettings] = useState<{ [name: string]: { sortedBySum: boolean } }>({}); + const getSAESettings = (name: string) => saeSettings[name] || { sortedBySum: false }; + + useEffect(() => { + setSelectedTokenGroupIndices([]); + }, [sample]); + + const tokens = useMemo(() => { + const saeInfo = + sample.saeInfo.length > 0 + ? zip( + ...sample.saeInfo.map((sae) => + zip(sae.featureActsIndices, sae.featureActs, sae.maxFeatureActs).map( + ([featureActIndex, featureAct, maxFeatureAct]) => ({ + name: sae.name, + featureActs: zip(featureActIndex, featureAct, maxFeatureAct).map( + ([featureActIndex, featureAct, maxFeatureAct]) => ({ + featureActIndex, + featureAct, + maxFeatureAct, + }) + ), + }) + ) + ) + ) + : sample.context.map(() => []); + + return zip(sample.context, sample.inputMask, sample.logits, sample.logitsTokens, saeInfo).map( + ([token, inputMask, logits, logitsTokens, saeInfo]) => ({ + token, + inputMask, + logits: zip(logits, logitsTokens).map(([logits, token]) => ({ + logits, + token, + })), + saeInfo, + }) + ); + }, [sample]); + + type Token = (typeof tokens)[0]; + + const sortTokenInfo = (tokens: Token[]) => { + const featureActSum = tokens.reduce((acc, token) => { + token.saeInfo.forEach((saeInfo) => { + saeInfo.featureActs.forEach((featureAct) => { + acc[saeInfo.name] = acc[saeInfo.name] || {}; + acc[saeInfo.name][featureAct.featureActIndex.toString()] = + acc[saeInfo.name][featureAct.featureActIndex.toString()] || 0; + acc[saeInfo.name][featureAct.featureActIndex.toString()] += featureAct.featureAct; + }); + }); + return acc; + }, {} as { [name: string]: { [featureIndex: string]: number } }); + + return tokens.map((token) => ({ + ...token, + logits: token.logits.sort((a, b) => b.logits - a.logits), + saeInfo: token.saeInfo.map((saeInfo) => ({ + ...saeInfo, + featureActs: getSAESettings(saeInfo.name).sortedBySum + ? 
saeInfo.featureActs.sort( + (a, b) => + featureActSum[saeInfo.name][b.featureActIndex.toString()] - + featureActSum[saeInfo.name][a.featureActIndex.toString()] + ) + : saeInfo.featureActs.sort((a, b) => b.featureAct - a.featureAct), + })), + })); + }; + + const tokenGroups = groupToken(tokens); + const tokenGroupPositions = countTokenGroupPositions(tokenGroups); + const selectedTokenGroups = selectedTokenGroupIndices.map((i) => tokenGroups[i]); + const selectedTokens = sortTokenInfo(selectedTokenGroups.flatMap((t) => t)); + const selectedTokenGroupPositions = selectedTokenGroupIndices.map((i) => tokenGroupPositions[i]); + const selectedTokenPositions = selectedTokenGroups.flatMap((t, i) => + t.map((_, j) => selectedTokenGroupPositions[i] + j) + ); + + return ( +
+ + cn( + "hover:shadow-lg hover:text-gray-600 cursor-pointer", + selectedTokenGroupIndices.some((t) => t === tokenIndex) && "bg-orange-500" + ) + } + tokenGroupProps={(_, i) => ({ + onClick: () => toggleSelectedTokenGroupIndex(i), + })} + /> + + {selectedTokens.length > 0 && ( +

+ Detail of {selectedTokens.length} Selected Token{selectedTokens.length > 1 ? "s" : ""}: +

+ )} + + {selectedTokens.map((token, i) => ( + +
+
+
Token:
+
{hex(token)}
+
Position:
+
{selectedTokenPositions[i]}
+
+
+
+ + {token.saeInfo.map((saeInfo, j) => ( + setSAESettings({ ...saeSettings, [saeInfo.name]: settings })} + /> + ))} +
+ {i < selectedTokens.length - 1 && } +
+ ))} + + {selectedTokens.length > 0 && } +
+ ); +}; + +const ModelCustomInputArea = () => { + const [customInput, setCustomInput] = useState(""); + const [doGenerate, setDoGenerate] = useState(true); + const [maxNewTokens, setMaxNewTokens] = useState(128); + const [topK, setTopK] = useState(50); + const [topP, setTopP] = useState(0.95); + const [selectedDictionaries, setSelectedDictionaries] = useState([]); + const [steerings, setSteerings] = useState< + { + sae: string | null; + featureIndex: number; + steeringType: "times" | "ablate" | "add" | "set"; + steeringValue: number | null; + }[] + >([{ sae: null, featureIndex: 0, steeringType: "times", steeringValue: 1 }]); + + const [sample, setSample] = useState(null); + + const [dictionariesState, fetchDictionaries] = useAsyncFn(async () => { + return await fetch(`${import.meta.env.VITE_BACKEND_URL}/dictionaries`) + .then(async (res) => await res.json()) + .then((res) => z.array(z.string()).parse(res)); + }); + + useMount(async () => { + await fetchDictionaries(); + }); + + const [state, submit] = useAsyncFn(async () => { + if (!customInput) { + alert("Please enter your input."); + return; + } + const sample = await fetch(`${import.meta.env.VITE_BACKEND_URL}/model/generate`, { + method: "POST", + body: JSON.stringify( + snakecaseKeys({ + inputText: customInput, + maxNewTokens: doGenerate ? maxNewTokens : 0, + topK, + topP, + saes: selectedDictionaries, + steerings: steerings.filter((s) => s.sae !== null), + }) + ), + headers: { + Accept: "application/x-msgpack", + "Content-Type": "application/json", + }, + }) + .then(async (res) => { + if (!res.ok) { + throw new Error(await res.text()); + } + return res; + }) + .then(async (res) => await res.arrayBuffer()) + // eslint-disable-next-line @typescript-eslint/no-explicit-any + .then((res) => decode(new Uint8Array(res)) as any) + .then((res) => + camelcaseKeys(res, { + deep: true, + stopPaths: ["context", "logits_tokens"], + }) + ) + .then((res) => ModelGenerationSchema.parse(res)); + setSample(sample); + }, [doGenerate, customInput, maxNewTokens, topK, topP, selectedDictionaries]); + + return ( +
+

Generation

+
+ Do generate: +
+ + + {doGenerate + ? "Model will generate new contents based on the following configuration." + : "Model will only perceive the given input."} + +
+ Max new tokens: + setMaxNewTokens(parseInt(e.target.value))} + /> + Top K: + setTopK(parseInt(e.target.value))} + /> + Top P: + setTopP(parseFloat(e.target.value))} + /> + + SAEs: + ({ value: name, label: name })) || []} + commandProps={{ + className: "col-span-3", + }} + hidePlaceholderWhenSelected + placeholder="Bind SAEs to the language model to see features activated in the generation." + value={selectedDictionaries.map((name) => ({ value: name, label: name }))} + onChange={(value) => setSelectedDictionaries(value.map((v) => v.value))} + emptyIndicator={ +

No dictionaries found.

+ } + /> + {selectedDictionaries.length > 0 && + steerings.map((steering, i) => ( + + Steering {i + 1}: +
+
+ { + setSteerings((prev) => prev.map((s, j) => (i === j ? { ...s, sae: value } : s))); + }} + options={selectedDictionaries.map((name) => ({ value: name, label: name }))} + disabled={state.loading} + placeholder="Select a dictionary to steer the generation." + commandPlaceholder="Search for a selected dictionary..." + emptyIndicator="No options found" + /> + # + { + setSteerings((prev) => + prev.map((s, j) => (i === j ? { ...s, featureIndex: parseInt(e.target.value) } : s)) + ); + }} + /> + + { + setSteerings((prev) => + prev.map((s, j) => (i === j ? { ...s, steeringValue: parseFloat(e.target.value) } : s)) + ); + }} + /> +
+
+ steerings.length > 1 + ? setSteerings((prev) => prev.filter((_, j) => i !== j)) + : setSteerings([{ sae: null, featureIndex: 0, steeringType: "times", steeringValue: 1 }]) + } + > + +
+
+ setSteerings((prev) => [ + ...prev.slice(0, i + 1), + { sae: null, featureIndex: 0, steeringType: "times", steeringValue: 1 }, + ...prev.slice(i + 1), + ]) + } + > + +
+
+
+ ))} +
+