Commit
…SAEs into ft4supp
Hzfinfdu committed Aug 7, 2024
2 parents 5ef6d09 + f80cb0e commit 6c88ed6
Showing 12 changed files with 250 additions and 91 deletions.
12 changes: 11 additions & 1 deletion src/lm_saes/config.py
@@ -398,7 +398,7 @@ def __post_init__(self):
self.lr_cool_down_steps = int(self.lr_cool_down_steps * total_training_steps)
print_once(f"Learning rate cool down steps: {self.lr_cool_down_steps}")
if self.finetuning:
assert self.l1_coefficient == 0.0, "L1 coefficient must be 0.0 for finetuning."
assert self.sae.l1_coefficient == 0.0, "L1 coefficient must be 0.0 for finetuning."

@dataclass(kw_only=True)
class LanguageModelSAEPruningConfig(LanguageModelSAERunnerConfig):
@@ -476,6 +476,16 @@ class LanguageModelSAEAnalysisConfig(RunnerConfig):
}
)

n_sae_chunks: int = (
1 # Number of chunks to split the SAE into for analysis. For large models and SAEs, this can be useful to avoid memory issues.
)

def __post_init__(self):
super().__post_init__()
assert (
self.sae.d_sae % self.n_sae_chunks == 0
), f"d_sae ({self.sae.d_sae}) must be divisible by n_sae_chunks ({self.n_sae_chunks})"


@dataclass(kw_only=True)
class FeaturesDecoderConfig(RunnerConfig):
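The new n_sae_chunks option lets the analysis runner process the SAE feature dimension in equal-sized chunks instead of all at once, which keeps memory bounded for large models and SAEs. A minimal illustrative sketch of the invariant enforced by the assertion above (the numbers and the chunking helper are not part of the commit):

```python
# Illustrative only: why d_sae must be divisible by n_sae_chunks.
d_sae, n_sae_chunks = 24576, 4
assert d_sae % n_sae_chunks == 0, "d_sae must be divisible by n_sae_chunks"

chunk_size = d_sae // n_sae_chunks
# Feature index ranges that could be analysed one at a time to bound memory use.
chunks = [(i * chunk_size, (i + 1) * chunk_size) for i in range(n_sae_chunks)]
print(chunks)  # [(0, 6144), (6144, 12288), (12288, 18432), (18432, 24576)]
```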
23 changes: 6 additions & 17 deletions src/lm_saes/runner.py
@@ -50,6 +50,11 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
if is_master():
cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))
sae = SparseAutoEncoder.from_config(cfg=cfg.sae)

if cfg.finetuning:
# Fine-tune SAE with frozen encoder weights and bias
sae.train_finetune_for_suppression_parameters()

hf_model = AutoModelForCausalLM.from_pretrained(
(
@@ -85,24 +90,7 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
model.eval()
activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store)

if not cfg.finetuning and (
cfg.sae.norm_activation == "dataset-wise" and cfg.sae.dataset_average_activation_norm is None
or cfg.sae.init_decoder_norm is None
):
sae = SparseAutoEncoder.from_initialization_searching(
activation_store=activation_store,
cfg=cfg,
)
else:
sae = SparseAutoEncoder.from_config(cfg=cfg.sae)

if cfg.finetuning:
# Fine-tune SAE with frozen encoder weights and bias
sae.train_finetune_for_suppression_parameters()

if is_master():
cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))

if cfg.wandb.log_to_wandb and is_master():
wandb_config: dict = {
@@ -392,6 +380,7 @@ def sample_feature_activations_runner(cfg: LanguageModelSAEAnalysisConfig):
del activation_store
torch.cuda.empty_cache()


@torch.no_grad()
def features_to_logits_runner(cfg: FeaturesDecoderConfig):
sae = SparseAutoEncoder.from_config(cfg=cfg.sae)
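As shown in this runner change, the SAE is now constructed directly from cfg.sae before the language model and activation store are built, replacing the earlier initialization-searching branch in this function; for suppression fine-tuning the encoder is then frozen via train_finetune_for_suppression_parameters. That method's body is not shown in this diff, so the sketch below is only a rough illustration of what freezing the encoder weights and bias could look like; the parameter names are assumptions.

```python
import torch.nn as nn

def freeze_encoder_for_suppression_finetuning(sae: nn.Module) -> None:
    """Hypothetical sketch: keep decoder parameters trainable, freeze the encoder."""
    for name, param in sae.named_parameters():
        # Assumed naming convention ("encoder.*" vs. "decoder.*"); adapt to the real module.
        param.requires_grad = not name.startswith("encoder")
```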
88 changes: 50 additions & 38 deletions src/lm_saes/sae.py
@@ -1,3 +1,4 @@
from builtins import print
from importlib.metadata import version
import os
from typing import Dict, Literal, Union, overload, List
@@ -117,7 +118,9 @@ def initialize_parameters(self):
if self.cfg.init_encoder_with_decoder_transpose:
self.encoder.weight.data = self.decoder.weight.data.T.clone().contiguous()
else:
self.set_encoder_norm_to_fixed_norm(self.cfg.init_encoder_norm)
self.set_encoder_norm_to_fixed_norm(
self.cfg.init_encoder_norm, during_init=True
)

def train_base_parameters(self):
"""Set the base parameters to be trained."""
@@ -261,7 +264,7 @@ def encode(

if self.cfg.use_decoder_bias and self.cfg.apply_decoder_bias_to_pre_encoder:
x = (
x - self.decoder.bias.to_local() # type: ignore
x - self.decoder.bias.to_local() # type: ignore
if self.cfg.tp_size > 1
else x - self.decoder.bias
)
@@ -479,44 +482,39 @@ def set_decoder_norm_to_fixed_norm(
decoder_norm = self.decoder_norm(keepdim=True, during_init=during_init)
if force_exact is None:
force_exact = self.cfg.decoder_exactly_fixed_norm

if self.cfg.tp_size > 1 and not during_init:
decoder_norm = distribute_tensor(
decoder_norm,
device_mesh=self.device_mesh["tp"],
placements=[Shard(0)],
)

if force_exact:
self.decoder.weight.data = self.decoder.weight.data * value / decoder_norm
self.decoder.weight.data *= value / decoder_norm
else:
# Set the norm of the decoder to not exceed value
self.decoder.weight.data = (
self.decoder.weight.data * value / torch.clamp(decoder_norm, min=value)
)
self.decoder.weight.data *= value / torch.clamp(decoder_norm, min=value)

@torch.no_grad()
def set_encoder_norm_to_fixed_norm(self, value: float | None = 1.0):
def set_encoder_norm_to_fixed_norm(
self, value: float | None = 1.0, during_init: bool = False
):
if self.cfg.use_glu_encoder:
raise NotImplementedError("GLU encoder not supported")
if value is None:
print(
f"Encoder norm is not set to a fixed value, using random initialization."
)
return
encoder_norm = self.encoder_norm(keepdim=True)
self.encoder.weight.data = self.encoder.weight.data * value / encoder_norm

@torch.no_grad()
def transform_to_unit_decoder_norm(self):
"""
If we include decoder norm in the sparsity loss, the final decoder norm is not guaranteed to be 1.
We make an equivalent transformation to the decoder to make it unit norm.
See https://transformer-circuits.pub/2024/april-update/index.html#training-saes
"""
assert (
self.cfg.sparsity_include_decoder_norm
), "Decoder norm is not included in the sparsity loss"
if self.cfg.use_glu_encoder:
raise NotImplementedError("GLU encoder not supported")

decoder_norm = self.decoder_norm() # (d_sae,)
self.encoder.weight.data = self.encoder.weight.data * decoder_norm[:, None]
self.decoder.weight.data = self.decoder.weight.data / decoder_norm

self.encoder.bias.data = self.encoder.bias.data * decoder_norm
encoder_norm = self.encoder_norm(keepdim=True, during_init=during_init)
if self.cfg.tp_size > 1 and not during_init:
encoder_norm = distribute_tensor(
encoder_norm,
device_mesh=self.device_mesh["tp"],
placements=[Shard(0)],
)
self.encoder.weight.data *= (value / encoder_norm)

@torch.no_grad()
def remove_gradient_parallel_to_decoder_directions(self):
@@ -622,9 +620,7 @@ def from_initialization_searching(
activation_store: ActivationStore,
cfg: LanguageModelSAETrainingConfig,
):
test_batch = activation_store.next(
batch_size=cfg.train_batch_size
)
test_batch = activation_store.next(batch_size=cfg.train_batch_size)
activation_in, activation_out = test_batch[cfg.sae.hook_point_in], test_batch[cfg.sae.hook_point_out] # type: ignore

if (
@@ -717,11 +713,27 @@ def save_pretrained(self, ckpt_path: str) -> None:
if os.path.isdir(ckpt_path):
ckpt_path = os.path.join(ckpt_path, "sae_weights.safetensors")
state_dict = self.get_full_state_dict()

@torch.no_grad()
def transform_to_unit_decoder_norm(
state_dict: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
decoder_norm = torch.norm(
state_dict["decoder.weight"], p=2, dim=0, keepdim=False
)
state_dict["decoder.weight"] = state_dict["decoder.weight"] / decoder_norm
state_dict["encoder.weight"] = (
state_dict["encoder.weight"] * decoder_norm[:, None]
)
state_dict["encoder.bias"] = state_dict["encoder.bias"] * decoder_norm
return state_dict

if self.cfg.sparsity_include_decoder_norm:
state_dict = transform_to_unit_decoder_norm(state_dict)

if is_master():
if ckpt_path.endswith(".safetensors"):
safe.save_file(
state_dict, ckpt_path, {"version": version("lm-saes")}
)
safe.save_file(state_dict, ckpt_path, {"version": version("lm-saes")})
elif ckpt_path.endswith(".pt"):
torch.save(
{"sae": state_dict, "version": version("lm-saes")}, ckpt_path
@@ -737,8 +749,8 @@ def decoder_norm(self, keepdim: bool = False, during_init: bool = False):
return torch.norm(self.decoder.weight, p=2, dim=0, keepdim=keepdim)
else:
decoder_norm = torch.norm(
self.decoder.weight.to_local(), p=2, dim=0, keepdim=keepdim # type: ignore
)
self.decoder.weight.to_local(), p=2, dim=0, keepdim=keepdim # type: ignore
)
decoder_norm = DTensor.from_local(
decoder_norm,
device_mesh=self.device_mesh["tp"],
@@ -758,8 +770,8 @@ def encoder_norm(
return torch.norm(self.encoder.weight, p=2, dim=1, keepdim=keepdim)
else:
encoder_norm = torch.norm(
self.encoder.weight.to_local(), p=2, dim=1, keepdim=keepdim # type: ignore
)
self.encoder.weight.to_local(), p=2, dim=1, keepdim=keepdim # type: ignore
)
encoder_norm = DTensor.from_local(
encoder_norm, device_mesh=self.device_mesh["tp"], placements=[Shard(0)]
)
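save_pretrained now applies the unit-decoder-norm transformation to the state dict at save time when sparsity_include_decoder_norm is set, rather than mutating the trained model (see the corresponding sae_training.py change below). The rescaling is an equivalence transformation for a ReLU SAE: dividing each decoder column by its norm while multiplying the matching encoder row and bias entry by that norm leaves reconstructions unchanged, because ReLU is positively homogeneous. A quick numeric check of that claim, assuming the weight shapes implied by the norm computations above (encoder.weight of shape (d_sae, d_model), decoder.weight of shape (d_model, d_sae)) and no pre-encoder decoder bias:

```python
import torch

torch.manual_seed(0)
d_model, d_sae = 8, 32
W_e, b_e = torch.randn(d_sae, d_model), torch.randn(d_sae)  # encoder.weight, encoder.bias
W_d = torch.randn(d_model, d_sae)                           # decoder.weight
x = torch.randn(4, d_model)                                 # a batch of activations

def reconstruct(W_e, b_e, W_d, x):
    feature_acts = torch.relu(x @ W_e.T + b_e)
    return feature_acts @ W_d.T

norm = torch.norm(W_d, p=2, dim=0)  # per-feature decoder column norms, shape (d_sae,)
before = reconstruct(W_e, b_e, W_d, x)
after = reconstruct(W_e * norm[:, None], b_e * norm, W_d / norm, x)
assert torch.allclose(before, after, atol=1e-5)
```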
4 changes: 1 addition & 3 deletions src/lm_saes/sae_training.py
@@ -342,9 +342,7 @@ def train_sae(
pbar.close()

# Save the final model
if cfg.sae.sparsity_include_decoder_norm:
sae.transform_to_unit_decoder_norm()
else:
if not cfg.sae.sparsity_include_decoder_norm:
sae.set_decoder_norm_to_fixed_norm(1)
path = os.path.join(
cfg.exp_result_dir, cfg.exp_name, "checkpoints", "final.safetensors"
2 changes: 2 additions & 0 deletions src/lm_saes/utils/huggingface.py
@@ -4,6 +4,7 @@
import os
import shutil
from huggingface_hub import create_repo, upload_folder, snapshot_download
from lm_saes.utils.misc import print_once


def upload_pretrained_sae_to_hf(sae_path: str, repo_id: str, private: bool = False):
@@ -54,6 +55,7 @@ def parse_pretrained_name_or_path(pretrained_name_or_path: str):
if os.path.exists(pretrained_name_or_path):
return pretrained_name_or_path
else:
print_once(f'Local path `{pretrained_name_or_path}` not found. Downloading from huggingface model hub.')
repo_id = "/".join(pretrained_name_or_path.split("/")[:2])
hook_point = "/".join(pretrained_name_or_path.split("/")[2:])
return download_pretrained_sae_from_hf(repo_id, hook_point)
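The added print_once makes the fallback explicit: if the given path does not exist locally, the SAE is fetched from the Hugging Face Hub instead. For reference, how a hub-style name is split into repo_id and hook_point by the code above (the string is made up for illustration):

```python
pretrained_name_or_path = "some-org/some-sae-repo/blocks.3.hook_resid_post"  # illustrative name
repo_id = "/".join(pretrained_name_or_path.split("/")[:2])     # "some-org/some-sae-repo"
hook_point = "/".join(pretrained_name_or_path.split("/")[2:])  # "blocks.3.hook_resid_post"
```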
60 changes: 60 additions & 0 deletions ui/src/components/app/section-navigator.tsx
@@ -0,0 +1,60 @@
import { cn } from "@/lib/utils";
import { Card, CardContent, CardHeader, CardTitle } from "../ui/card";
import { useEffect, useState } from "react";

export const SectionNavigator = ({ sections }: { sections: { title: string; id: string }[] }) => {
const [activeSection, setActiveSection] = useState<{ title: string; id: string } | null>(null);

const handleScroll = () => {
// Use reduce instead of find for obtaining the last section that is in view
const currentSection = sections.reduce((result: { title: string; id: string } | null, section) => {
const secElement = document.getElementById(section.id);
if (!secElement) return result;
const rect = secElement.getBoundingClientRect();
if (rect.top <= window.innerHeight / 2) {
return section;
}
return result;
}, null);

setActiveSection(currentSection);
};

useEffect(() => {
window.addEventListener("scroll", handleScroll);

// Run the handler to set the initial active section
handleScroll();

return () => {
window.removeEventListener("scroll", handleScroll);
};
});

return (
<Card className="py-4 sticky top-0 w-60 h-full bg-transparent">
<CardHeader className="py-0">
<CardTitle className="flex justify-between items-center text-xs p-2">
<span className="font-bold">CONTENTS</span>
</CardTitle>
</CardHeader>
<CardContent className="py-0">
<div className="flex flex-col">
<ul>
{sections.map((section) => (
<li key={section.id} className="relative">
<a
href={"#" + section.id}
className={cn("p-2 block text-neutral-700", activeSection === section && "text-[blue]")}
>
{section.title}
</a>
{activeSection === section && <div className="absolute -left-1.5 top-0 bottom-0 w-0.5 bg-[blue]"></div>}
</li>
))}
</ul>
</div>
</CardContent>
</Card>
);
};
24 changes: 13 additions & 11 deletions ui/src/components/dictionary/sample.tsx
@@ -136,17 +136,19 @@ export const DictionarySampleArea = ({ samples, onSamplesChange, dictionaryName
...featureAct,
}))
)
.reduce((acc, featureAct) => {
// Group by featureActIndex
const key = featureAct.featureActIndex.toString();
if (acc[key]) {
acc[key].push(featureAct);
} else {
acc[key] = [featureAct];
}
return acc;
}, {} as Record<string, { token: Uint8Array; tokenIndex: number; featureAct: number; maxFeatureAct: number }[]>) ||
{}
.reduce(
(acc, featureAct) => {
// Group by featureActIndex
const key = featureAct.featureActIndex.toString();
if (acc[key]) {
acc[key].push(featureAct);
} else {
acc[key] = [featureAct];
}
return acc;
},
{} as Record<string, { token: Uint8Array; tokenIndex: number; featureAct: number; maxFeatureAct: number }[]>
) || {}
)
.sort(
// Sort by sum of featureAct
19 changes: 14 additions & 5 deletions ui/src/components/feature/feature-card.tsx
@@ -88,7 +88,7 @@ export const FeatureCard = ({ feature }: { feature: Feature }) => {
const [showCustomInput, setShowCustomInput] = useState<boolean>(false);

return (
<Card className="container">
<Card id="Interp." className="container">
<CardHeader>
<CardTitle className="flex justify-between items-center text-xl">
<span>
@@ -108,7 +108,7 @@ export const FeatureCard = ({ feature }: { feature: Feature }) => {

<FeatureInterpretation feature={feature} />

<div className="flex flex-col w-full gap-4">
<div id="Histogram" className="flex flex-col w-full gap-4">
<p className="font-bold">Activation Histogram</p>
<Plot
data={feature.featureActivationHistogram}
@@ -123,7 +123,7 @@
</div>

{feature.logits && (
<div className="flex flex-col w-full gap-4">
<div id="Logits" className="flex flex-col w-full gap-4">
<p className="font-bold">Logits</p>
<div className="flex gap-4">
<div className="flex flex-col w-1/2 gap-4">
@@ -180,15 +180,24 @@ export const FeatureCard = ({ feature }: { feature: Feature }) => {
</div>
)}

<div className="flex flex-col w-full gap-4">
<div id="Activation" className="flex flex-col w-full gap-4">
<Tabs defaultValue="top_activations">
<TabsList className="font-bold">
{feature.sampleGroups.map((sampleGroup) => (
{feature.sampleGroups.slice(0, feature.sampleGroups.length / 2).map((sampleGroup) => (
<TabsTrigger key={`tab-trigger-${sampleGroup.analysisName}`} value={sampleGroup.analysisName}>
{analysisNameMap(sampleGroup.analysisName)}
</TabsTrigger>
))}
</TabsList>
<TabsList className="font-bold">
{feature.sampleGroups
.slice(feature.sampleGroups.length / 2, feature.sampleGroups.length)
.map((sampleGroup) => (
<TabsTrigger key={`tab-trigger-${sampleGroup.analysisName}`} value={sampleGroup.analysisName}>
{analysisNameMap(sampleGroup.analysisName)}
</TabsTrigger>
))}
</TabsList>
{feature.sampleGroups.map((sampleGroup) => (
<TabsContent
key={`tab-content-${sampleGroup.analysisName}`}
