diff --git a/examples/programmatic/train.py b/examples/programmatic/train.py
index dd7e695..8b3bb91 100644
--- a/examples/programmatic/train.py
+++ b/examples/programmatic/train.py
@@ -2,42 +2,51 @@
 from lm_saes.config import LanguageModelSAETrainingConfig
 from lm_saes.runner import language_model_sae_runner
+
 cfg = LanguageModelSAETrainingConfig.from_flattened(dict(
     # LanguageModelConfig
     model_name = "gpt2",  # The model name or path for the pre-trained model.
     d_model = 768,  # The hidden size of the model.
 
     # TextDatasetConfig
-    dataset_path = "openwebtext",  # The corpus name or path. Each of a data record should contain (and may only contain) a "text" field.
+    dataset_path = 'Skylion007/OpenWebText',  # The corpus name or path. Each of a data record should contain (and may only contain) a "text" field.
     is_dataset_tokenized = False,  # Whether the dataset is tokenized.
     is_dataset_on_disk = True,  # Whether the dataset is on disk. If not on disk, `datasets.load_dataset` will be used to load the dataset, and the train split will be used for training.
-    concat_tokens = False,  # Whether to concatenate tokens into a single sequence. If False, only data record with length of non-padding tokens larger than `context_size` will be used.
-    context_size = 256,  # The sequence length of the text dataset.
-    store_batch_size = 32,  # The batch size for loading the corpus.
+    concat_tokens = True,  # Whether to concatenate tokens into a single sequence. If False, only data record with length of non-padding tokens larger than `context_size` will be used.
+    context_size = 1024,  # The sequence length of the text dataset.
+    store_batch_size = 20,  # The batch size for loading the corpus.
 
     # ActivationStoreConfig
-    hook_points = ["blocks.3.hook_mlp_out"],  # Hook points to store activations from, i.e. the layer output of which is used for training/evaluating the dictionary. Will run until the last hook point in the list, so make sure to order them correctly.
+    hook_points = ['blocks.8.hook_resid_pre'],  # Hook points to store activations from, i.e. the layer output of which is used for training/evaluating the dictionary. Will run until the last hook point in the list, so make sure to order them correctly.
     use_cached_activations = False,  # Whether to use cached activations. Caching activation is now not recommended, as it may consume extremely large disk space. (May be tens of TBs for corpus like `openwebtext`)
     n_tokens_in_buffer = 500_000,  # The number of tokens to store in the activation buffer. The buffer is used to shuffle the activations before training the dictionary.
 
     # SAEConfig
-    hook_point_in = "blocks.3.hook_mlp_out",
-    hook_point_out = "blocks.3.hook_mlp_out",
-    expansion_factor = 32,  # The expansion factor of the dictionary. d_sae = expansion_factor * d_model.
+    hook_point_in = 'blocks.8.hook_resid_pre',
+    hook_point_out = 'blocks.8.hook_resid_pre',
+    use_decoder_bias = True,  # Whether to use decoder bias.
+    expansion_factor = 128,  # The expansion factor of the dictionary. d_sae = expansion_factor * d_model.
     norm_activation = "token-wise",  # The normalization method for the activations. Can be "token-wise", "batch-wise" or "none".
-    decoder_exactly_unit_norm = False,  # Whether to enforce the decoder to have exactly unit norm. If False, the decoder will have less than or equal to unit norm.
+    decoder_exactly_fixed_norm = False,  # Whether to enforce the decoder to have exactly unit norm. If False, the decoder will have less than or equal to unit norm.
     use_glu_encoder = False,  # Whether to use the Gated Linear Unit (GLU) for the encoder.
-    l1_coefficient = 1.2e-4,  # The L1 regularization coefficient for the feature activations.
+    l1_coefficient = 2e-4,  # The L1 regularization coefficient for the feature activations.
+    l1_coefficient_warmup_steps = 10000,  # The number of warm-up steps for the L1 regularization coefficient.
     lp = 1,  # The p-norm to use for the L1 regularization.
-    use_ghost_grads = True,  # Whether to use the ghost gradients for saving dead features.
+    use_ghost_grads = False,  # Whether to use the ghost gradients for saving dead features.
+    init_decoder_norm = None,  # The initial norm of the decoder. If None, the decoder will be initialized automatically with the lowest MSE.
+    init_encoder_with_decoder_transpose = True,
+    apply_decoder_bias_to_pre_encoder = True,
+    sparsity_include_decoder_norm = True,
 
     # LanguageModelSAETrainingConfig
-    total_training_tokens = 1_600_000_000,  # The total number of tokens to train the dictionary.
-    lr = 4e-4,  # The learning rate for the dictionary training.
-    betas = (0, 0.9999),  # The betas for the Adam optimizer.
+    total_training_tokens = 100_000_000,  # The total number of tokens to train the dictionary.
+    lr = 1e-4,  # The learning rate for the dictionary training.
+    betas = (0.9, 0.9999),  # The betas for the Adam optimizer.
+    lr_scheduler_name = "constantwithwarmup",  # The learning rate scheduler name. Can be "constant", "constantwithwarmup", "linearwarmupdecay", "cosineannealing", "cosineannealingwarmup" or "exponentialwarmup".
 
-    lr_warm_up_steps = 5000,  # The number of warm-up steps for the learning rate.
-    lr_cool_down_steps = 10000,  # The number of cool-down steps for the learning rate. Currently only used for the "constantwithwarmup" scheduler.
+    lr_warm_up_steps = 2000,  # The number of warm-up steps for the learning rate.
+    lr_cool_down_steps = 4000,  # The number of cool-down steps for the learning rate. Currently only used for the "constantwithwarmup" scheduler.
+    clip_grad_norm = 0.0,  # The maximum gradient norm for clipping. If 0.0, no gradient clipping will be performed.
     train_batch_size = 4096,  # The batch size for training the dictionary, i.e. the number of token activations in a batch.
     feature_sampling_window = 1000,  # The window size for sampling the feature activations.
     dead_feature_window = 5000,  # The window size for detecting the dead features.
@@ -45,18 +54,20 @@
     eval_frequency = 1000,  # The step frequency for evaluating the dictionary.
     log_frequency = 100,  # The step frequency for logging the training information (to wandb).
     n_checkpoints = 10,  # The number of checkpoints to save during the training.
+    remove_gradient_parallel_to_decoder_directions = False,
+
     # WandbConfig
     log_to_wandb = True,  # Whether to log the training information to wandb.
-    wandb_project= "gpt2-sae",  # The wandb project name.
-
+    wandb_project= "test",  # The wandb project name.
+
     # RunnerConfig
     device = "cuda",  # The device to place all torch tensors.
     seed = 42,  # The random seed.
     dtype = torch.float32,  # The torch data type of non-integer tensors.
-    exp_name = "L3M",  # The experiment name. Would be used for creating exp folder (which may contain checkpoints and analysis results) and setting wandb run name.
-    exp_series = "default",
+    exp_name = f"test",  # The experiment name. Would be used for creating exp folder (which may contain checkpoints and analysis results) and setting wandb run name.
+    exp_series = "test",
     exp_result_dir = "results"
 ))
diff --git a/src/lm_saes/activation/activation_source.py b/src/lm_saes/activation/activation_source.py
index ce8a01c..d0b5d40 100644
--- a/src/lm_saes/activation/activation_source.py
+++ b/src/lm_saes/activation/activation_source.py
@@ -40,7 +40,7 @@ def __init__(self, model: HookedTransformer, cfg: ActivationStoreConfig):
         self.cfg = cfg
 
     def next(self) -> Dict[str, torch.Tensor] | None:
-        tokens = self.token_source.next(self.cfg.dataset.store_batch_size)
+        tokens = self.next_tokens(self.cfg.dataset.store_batch_size)
 
         if tokens is None:
             return None
diff --git a/src/lm_saes/activation/token_source.py b/src/lm_saes/activation/token_source.py
index 2cb7f5e..f825e54 100644
--- a/src/lm_saes/activation/token_source.py
+++ b/src/lm_saes/activation/token_source.py
@@ -17,6 +17,7 @@ def __init__(
         concat_tokens: list[bool],
         seq_len: int,
         sample_probs: list[float],
+        prepend_bos: list[bool]
     ):
         self.dataloader = dataloader
         self.model = model
@@ -33,13 +34,14 @@ def __init__(
         self.resid = torch.tensor([], dtype=torch.long, device=self.device)
 
         self.sample_probs = sample_probs
+        self.prepend_bos = prepend_bos
 
-    def fill_with_one_batch(self, batch, pack) -> None:
+    def fill_with_one_batch(self, batch, pack: bool, prepend_bos: bool) -> None:
         if self.is_dataset_tokenized:
             tokens: torch.Tensor = batch["tokens"].to(self.device)
         else:
-            tokens = self.model.to_tokens(batch["text"], prepend_bos=False).to(self.device)
+            tokens = self.model.to_tokens(batch["text"], prepend_bos=prepend_bos).to(self.device)
         if pack:
             while tokens.size(0) > 0:
                 cur_tokens = tokens[0]
@@ -81,7 +83,7 @@ def next(self, batch_size: int) -> torch.Tensor | None:
             else:
                 return None
 
-            self.fill_with_one_batch(batch, self.concat_tokens[dataset_idx_to_fetch])
+            self.fill_with_one_batch(batch, self.concat_tokens[dataset_idx_to_fetch], prepend_bos=self.prepend_bos[dataset_idx_to_fetch])
 
         ret = self.token_buffer[:batch_size]
         self.token_buffer = self.token_buffer[batch_size:]
@@ -120,4 +122,5 @@ def from_config(model: HookedTransformer, cfg: TextDatasetConfig):
             concat_tokens=cfg.concat_tokens,
             seq_len=cfg.context_size,
             sample_probs=cfg.sample_probs,
+            prepend_bos=cfg.prepend_bos
         )
\ No newline at end of file
diff --git a/src/lm_saes/config.py b/src/lm_saes/config.py
index 785e6c2..808719e 100644
--- a/src/lm_saes/config.py
+++ b/src/lm_saes/config.py
@@ -106,6 +106,7 @@ class TextDatasetConfig(RunnerConfig):
     context_size: int = 128
     store_batch_size: int = 64
     sample_probs: List[float] = field(default_factory=lambda: [1.0])
+    prepend_bos: List[bool] = field(default_factory=lambda: [False])
 
     def __post_init__(self):
         super().__post_init__()
@@ -115,10 +116,14 @@ def __post_init__(self):
         if isinstance(self.concat_tokens, bool):
             self.concat_tokens = [self.concat_tokens]
 
+        if isinstance(self.prepend_bos, bool):
+            self.prepend_bos = [self.prepend_bos]
+
         self.sample_probs = [p / sum(self.sample_probs) for p in self.sample_probs]
 
         assert len(self.sample_probs) == len(self.dataset_path), "Number of sample_probs must match number of dataset paths"
         assert len(self.concat_tokens) == len(self.dataset_path), "Number of concat_tokens must match number of dataset paths"
+        assert len(self.prepend_bos) == len(self.dataset_path), "Number of prepend_bos must match number of dataset paths"
 
 
 @dataclass(kw_only=True)
@@ -163,21 +168,26 @@ class SAEConfig(BaseModelConfig):
     sae_pretrained_name_or_path: Optional[str] = None
     strict_loading: bool = True
 
-    use_decoder_bias: bool = False
+    use_decoder_bias: bool = True
     apply_decoder_bias_to_pre_encoder: bool = True  # set to False when training transcoders
-    decoder_bias_init_method: str = "geometric_median"
-    expansion_factor: int = 32
+    expansion_factor: int = 128
     d_model: int = 768
     d_sae: int = None  # type: ignore
     """ The dimension of the SAE, i.e. the number of dictionary components (or features). If None, it will be set to d_model * expansion_factor """
-    norm_activation: str = "token-wise"  # none, token-wise, batch-wise
-    decoder_exactly_unit_norm: bool = True
+    norm_activation: str = "token-wise"  # none, token-wise, batch-wise, dataset-wise
+    dataset_average_activation_norm: Dict[str, float] | None = None
+    decoder_exactly_fixed_norm: bool = False
+    sparsity_include_decoder_norm: bool = True  # set to True: sparsity loss = sum(act * corresponding_decoder_norm), otherwise loss = sum(act). Incompatible with decoder_exactly_fixed_norm
     use_glu_encoder: bool = False
+    init_decoder_norm: float | None = None  # type: ignore
+    init_encoder_norm: float | None = None  # type: ignore
+    init_encoder_with_decoder_transpose: bool = True
     l1_coefficient: float = 0.00008
+    l1_coefficient_warmup_steps: int = 0
     lp: int = 1
-    use_ghost_grads: bool = True
+    use_ghost_grads: bool = False
 
     def __post_init__(self):
         super().__post_init__()
@@ -185,6 +195,15 @@ def __post_init__(self):
             self.hook_point_out = self.hook_point_in
         if self.d_sae is None:
             self.d_sae = self.d_model * self.expansion_factor
+        if self.norm_activation == "dataset-wise" and self.dataset_average_activation_norm is None:
+            print(f'dataset_average_activation_norm is None and norm_activation is "dataset-wise". Will be computed automatically from the dataset.')
+        if self.sparsity_include_decoder_norm and self.decoder_exactly_fixed_norm:
+            raise ValueError("sparsity_include_decoder_norm and decoder_exactly_fixed_norm are incompatible.")
+        if self.sparsity_include_decoder_norm and self.use_ghost_grads:
+            raise ValueError("sparsity_include_decoder_norm and use_ghost_grads are incompatible.")
+        if self.init_encoder_with_decoder_transpose and isinstance(self.init_encoder_norm, float):
+            raise ValueError("init_encoder_with_decoder_transpose and init_encoder_norm with float are incompatible.")
+
 
     @staticmethod
     def from_pretrained(pretrained_name_or_path: str, strict_loading: bool = True, **kwargs):
@@ -269,6 +288,8 @@ class LanguageModelSAETrainingConfig(LanguageModelSAERunnerConfig):
     lr_warm_up_steps: int = 5000
     lr_cool_down_steps: int = 10000
     train_batch_size: int = 4096
+    clip_grad_norm: float = 0.0
+    remove_gradient_parallel_to_decoder_directions: bool = False
 
     finetuning: bool = False
diff --git a/src/lm_saes/runner.py b/src/lm_saes/runner.py
index e5f0224..6c2d1a7 100644
--- a/src/lm_saes/runner.py
+++ b/src/lm_saes/runner.py
@@ -30,15 +30,8 @@
 from lm_saes.analysis.features_to_logits import features_to_logits
 
 
-def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
-    cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
-    cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))
-    sae = SparseAutoEncoder.from_config(cfg=cfg.sae)
-
-    if cfg.finetuning:
-        # Fine-tune SAE with frozen encoder weights and bias
-        sae.train_finetune_for_suppression_parameters()
+def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
     hf_model = AutoModelForCausalLM.from_pretrained(
         (
             cfg.lm.model_name
@@ -72,6 +65,25 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
     model.eval()
     activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store)
+
+    if (
+        cfg.sae.norm_activation == "dataset-wise" and cfg.sae.dataset_average_activation_norm is None
+        or cfg.sae.init_decoder_norm is None
+    ):
+        assert not cfg.finetuning
+        sae = SparseAutoEncoder.from_initialization_searching(
+            activation_store=activation_store,
+            cfg=cfg,
+        )
+    else:
+        sae = SparseAutoEncoder.from_config(cfg=cfg.sae)
+
+    if cfg.finetuning:
+        # Fine-tune SAE with frozen encoder weights and bias
+        sae.train_finetune_for_suppression_parameters()
+
+    cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
+    cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))
+
     if cfg.wandb.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0):
         wandb_config: dict = {
             **asdict(cfg),
diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py
index 6dda1f6..9db2b3c 100644
--- a/src/lm_saes/sae.py
+++ b/src/lm_saes/sae.py
@@ -1,6 +1,6 @@
 from importlib.metadata import version
 import os
-from typing import Dict, Literal, Union, overload
+from typing import Dict, Literal, Union, overload, List
 import torch
 import math
 from einops import einsum
@@ -9,7 +9,8 @@
 import safetensors.torch as safe
 
-from lm_saes.config import SAEConfig
+from lm_saes.config import SAEConfig, LanguageModelSAETrainingConfig
+from lm_saes.activation.activation_store import ActivationStore
 from lm_saes.utils.huggingface import parse_pretrained_name_or_path
 
 class SparseAutoEncoder(HookedRootModule):
@@ -33,28 +34,26 @@ def __init__(
         super(SparseAutoEncoder, self).__init__()
 
         self.cfg = cfg
+        self.current_l1_coefficient = cfg.l1_coefficient
 
         self.encoder = torch.nn.Parameter(torch.empty((cfg.d_model, cfg.d_sae), dtype=cfg.dtype, device=cfg.device))
-        torch.nn.init.kaiming_uniform_(self.encoder)
 
         if cfg.use_glu_encoder:
             self.encoder_glu = torch.nn.Parameter(torch.empty((cfg.d_model, cfg.d_sae), dtype=cfg.dtype, device=cfg.device))
-            torch.nn.init.kaiming_uniform_(self.encoder_glu)
 
             self.encoder_bias_glu = torch.nn.Parameter(torch.empty((cfg.d_sae,), dtype=cfg.dtype, device=cfg.device))
-            torch.nn.init.zeros_(self.encoder_bias_glu)
 
         self.feature_act_mask = torch.nn.Parameter(torch.ones((cfg.d_sae,), dtype=cfg.dtype, device=cfg.device))
         self.feature_act_scale = torch.nn.Parameter(torch.ones((cfg.d_sae,), dtype=cfg.dtype, device=cfg.device))
 
         self.decoder = torch.nn.Parameter(torch.empty((cfg.d_sae, cfg.d_model), dtype=cfg.dtype, device=cfg.device))
-        torch.nn.init.kaiming_uniform_(self.decoder)
-        self.set_decoder_norm_to_unit_norm()
+
+
         if cfg.use_decoder_bias:
             self.decoder_bias = torch.nn.Parameter(torch.empty((cfg.d_model,), dtype=cfg.dtype, device=cfg.device))
 
         self.encoder_bias = torch.nn.Parameter(torch.empty((cfg.d_sae,), dtype=cfg.dtype, device=cfg.device))
-        torch.nn.init.zeros_(self.encoder_bias)
+
         self.train_base_parameters()
 
@@ -62,6 +61,29 @@ def __init__(
         self.hook_feature_acts = HookPoint()
         self.hook_reconstructed = HookPoint()
 
+        self.initialize_parameters()
+
+
+    def initialize_parameters(self):
+        torch.nn.init.kaiming_uniform_(self.encoder)
+
+        if self.cfg.use_glu_encoder:
+            torch.nn.init.kaiming_uniform_(self.encoder_glu)
+            torch.nn.init.zeros_(self.encoder_bias_glu)
+
+        torch.nn.init.kaiming_uniform_(self.decoder)
+        self.set_decoder_norm_to_fixed_norm(self.cfg.init_decoder_norm, force_exact=True)
+
+        if self.cfg.use_decoder_bias:
+            torch.nn.init.zeros_(self.decoder_bias)
+        torch.nn.init.zeros_(self.encoder_bias)
+
+        if self.cfg.init_encoder_with_decoder_transpose:
+            self.encoder.data = self.decoder.data.T.clone().contiguous()
+        else:
+            self.set_encoder_norm_to_fixed_norm(self.cfg.init_encoder_norm)
+
+
     def train_base_parameters(self):
         """Set the base parameters to be trained.
         """
@@ -96,7 +118,7 @@ def train_finetune_for_suppression_parameters(self):
             p.requires_grad_(True)
 
-    def compute_norm_factor(self, x: torch.Tensor) -> torch.Tensor:
+    def compute_norm_factor(self, x: torch.Tensor, hook_point: str) -> float | torch.Tensor:
         """Compute the normalization factor for the activation vectors.
         """
 
@@ -105,6 +127,9 @@ def compute_norm_factor(self, x: torch.Tensor) -> torch.Tensor:
             return math.sqrt(self.cfg.d_model) / torch.norm(x, 2, dim=-1, keepdim=True)
         elif self.cfg.norm_activation == "batch-wise":
             return math.sqrt(self.cfg.d_model) / torch.norm(x, 2, dim=-1, keepdim=True).mean(dim=-2, keepdim=True)
+        elif self.cfg.norm_activation == "dataset-wise":
+            assert self.cfg.dataset_average_activation_norm is not None, "dataset_average_activation_norm must be provided for dataset-wise normalization"
+            return math.sqrt(self.cfg.d_model) / self.cfg.dataset_average_activation_norm[hook_point]
         else:
             return torch.tensor(1.0, dtype=self.cfg.dtype, device=self.cfg.device)
 
@@ -148,7 +173,7 @@ def encode(
         if self.cfg.use_decoder_bias and self.cfg.apply_decoder_bias_to_pre_encoder:
             x = x - self.decoder_bias
 
-        x = x * self.compute_norm_factor(x)
+        x = x * self.compute_norm_factor(x, hook_point='in')
 
         hidden_pre = einsum(
             x,
@@ -165,7 +190,7 @@ def encode(
             hidden_pre_glu = torch.sigmoid(hidden_pre_glu)
             hidden_pre = hidden_pre * hidden_pre_glu
 
-        hidden_pre = hidden_pre / self.compute_norm_factor(label)
+        hidden_pre = hidden_pre / self.compute_norm_factor(label, hook_point='in')
         hidden_pre = self.hook_hidden_pre(hidden_pre)
 
         feature_acts = self.feature_act_mask * self.feature_act_scale * torch.clamp(hidden_pre, min=0.0)
@@ -221,11 +246,11 @@ def compute_loss(
         if label is None:
             label = x
 
-        label_norm_factor = self.compute_norm_factor(label)
+        label_norm_factor = self.compute_norm_factor(label, hook_point='out')
 
         feature_acts, hidden_pre = self.encode(x, label, return_hidden_pre=True)
-        feature_acts_normed = feature_acts * label_norm_factor
-        hidden_pre_normed = hidden_pre * label_norm_factor
+        feature_acts_normed = feature_acts * label_norm_factor  # (batch, d_sae)
+        # hidden_pre_normed = hidden_pre * label_norm_factor
 
         reconstructed = self.decode(feature_acts)
         reconstructed_normed = reconstructed * label_norm_factor
@@ -236,7 +261,11 @@ def compute_loss(
         l_rec = (reconstructed_normed - label_normed).pow(2) / (label_normed - label_normed.mean(dim=0, keepdim=True)).pow(2).sum(dim=-1, keepdim=True).clamp(min=1e-8).sqrt()
 
         # l_l1: (batch,)
-        l_l1 = torch.norm(feature_acts_normed, p=self.cfg.lp, dim=-1)
+        if self.cfg.sparsity_include_decoder_norm:
+            l_l1 = torch.norm(feature_acts_normed * torch.norm(self.decoder, p=2, dim=1), p=self.cfg.lp, dim=-1)
+        else:
+            l_l1 = torch.norm(feature_acts_normed, p=self.cfg.lp, dim=-1)
+
 
         l_ghost_resid = torch.tensor(0.0, dtype=self.cfg.dtype, device=self.cfg.device)
 
@@ -268,7 +297,7 @@ def compute_loss(
             mse_rescaling_factor = (l_rec / (l_ghost_resid + 1e-6)).detach()
             l_ghost_resid = mse_rescaling_factor * l_ghost_resid
 
-        loss = l_rec.mean() + self.cfg.l1_coefficient * l_l1.mean() + l_ghost_resid.mean()
+        loss = l_rec.mean() + self.current_l1_coefficient * l_l1.mean() + l_ghost_resid.mean()
 
         if return_aux_data:
             aux_data = {
@@ -296,15 +325,53 @@ def forward(
         reconstructed = self.decode(feature_acts)
 
         return reconstructed
+
+    @torch.no_grad()
+    def update_l1_coefficient(self, training_step):
+        if self.cfg.l1_coefficient_warmup_steps <= 0:
+            return
+        self.current_l1_coefficient = min(1., training_step / self.cfg.l1_coefficient_warmup_steps) * self.cfg.l1_coefficient
 
     @torch.no_grad()
-    def set_decoder_norm_to_unit_norm(self):
+    def set_decoder_norm_to_fixed_norm(self, value: float | None = 1.0, force_exact: bool | None = None):
+        if value is None:
+            return
         decoder_norm = torch.norm(self.decoder, dim=1, keepdim=True)
-        if self.cfg.decoder_exactly_unit_norm:
-            self.decoder.data = self.decoder.data / decoder_norm
+        if force_exact is None:
+            force_exact = self.cfg.decoder_exactly_fixed_norm
+        if force_exact:
+            self.decoder.data = self.decoder.data * value / decoder_norm
         else:
-            # Set the norm of the decoder to not exceed 1
-            self.decoder.data = self.decoder.data / torch.clamp(decoder_norm, min=1.0)
+            # Set the norm of the decoder to not exceed value
+            self.decoder.data = self.decoder.data * value / torch.clamp(decoder_norm, min=value)
+
+    @torch.no_grad()
+    def set_encoder_norm_to_fixed_norm(self, value: float | None = 1.0):
+        if self.cfg.use_glu_encoder:
+            raise NotImplementedError("GLU encoder not supported")
+        if value is None:
+            print(f'Encoder norm is not set to a fixed value, using random initialization.')
+            return
+        encoder_norm = torch.norm(self.encoder, dim=0, keepdim=True)  # [1, d_sae]
+        self.encoder.data = self.encoder.data * value / encoder_norm
+
+
+    @torch.no_grad()
+    def transform_to_unit_decoder_norm(self):
+        """
+        If we include decoder norm in the sparsity loss, the final decoder norm is not guaranteed to be 1.
+        We make an equivalent transformation to the decoder to make it unit norm.
+        See https://transformer-circuits.pub/2024/april-update/index.html#training-saes
+        """
+        assert self.cfg.sparsity_include_decoder_norm, "Decoder norm is not included in the sparsity loss"
+        if self.cfg.use_glu_encoder:
+            raise NotImplementedError("GLU encoder not supported")
+
+        decoder_norm = torch.norm(self.decoder, p=2, dim=1)  # (d_sae,)
+        self.encoder.data = self.encoder.data * decoder_norm
+        self.decoder.data = self.decoder.data / decoder_norm[:, None]
+
+        self.encoder_bias.data = self.encoder_bias.data * decoder_norm
 
@@ -398,6 +465,52 @@ def from_pretrained(
         cfg = SAEConfig.from_pretrained(pretrained_name_or_path, strict_loading=strict_loading, **kwargs)
         return SparseAutoEncoder.from_config(cfg)
+
+    @torch.no_grad()
+    @staticmethod
+    def from_initialization_searching(
+        activation_store: ActivationStore,
+        cfg: LanguageModelSAETrainingConfig,
+    ):
+        test_batch = activation_store.next(batch_size=cfg.train_batch_size * 8)  # just random hard code xd
+        activation_in, activation_out = test_batch[cfg.sae.hook_point_in], test_batch[cfg.sae.hook_point_out]  # type: ignore
+
+        if cfg.sae.norm_activation == "dataset-wise" and cfg.sae.dataset_average_activation_norm is None:
+            print(f'SAE: Computing average activation norm on the first {cfg.train_batch_size * 8} samples.')
+
+            average_in_norm, average_out_norm = activation_in.norm(p=2, dim=1).mean().item(), activation_out.norm(p=2,
+                dim=1).mean().item()
+
+            print(
+                f'Average input activation norm: {average_in_norm}\nAverage output activation norm: {average_out_norm}')
+            cfg.sae.dataset_average_activation_norm = {'in': average_in_norm, 'out': average_out_norm}
+
+        if cfg.sae.init_decoder_norm is None:
+            assert cfg.sae.sparsity_include_decoder_norm, 'Decoder norm must be included in sparsity loss'
+            if not cfg.sae.init_encoder_with_decoder_transpose or cfg.sae.hook_point_in != cfg.sae.hook_point_out:
+                raise NotImplementedError('Transcoders cannot be initialized automatically.')
+            print('SAE: Starting grid search for initial decoder norm.')
+
+            test_sae = SparseAutoEncoder.from_config(cfg=cfg.sae)
+
+            def grid_search_best_init_norm(search_range: List[float]) -> float:
+                losses: Dict[float, float] = {}
+                for norm in search_range:
+                    test_sae.set_decoder_norm_to_fixed_norm(norm, force_exact=True)
+                    test_sae.encoder.data = test_sae.decoder.data.T.clone().contiguous()
+                    mse = test_sae.compute_loss(x=activation_in, label=activation_out)[1][0]['l_rec'].mean().item()  # type: ignore
+                    losses[norm] = mse
+                best_norm = min(losses, key=losses.get)  # type: ignore
+                return best_norm
+
+            best_norm_coarse = grid_search_best_init_norm(torch.linspace(0.1, 1, 10).numpy().tolist())
+            best_norm_fine_grained = grid_search_best_init_norm(torch.linspace(best_norm_coarse - 0.09, best_norm_coarse + 0.1, 20).numpy().tolist())
+            print(f'The best (i.e. lowest MSE) initialized norm is {best_norm_fine_grained}')
+
+            test_sae.set_decoder_norm_to_fixed_norm(best_norm_fine_grained, force_exact=True)
+            test_sae.encoder.data = test_sae.decoder.data.T.clone().contiguous()
+
+        return test_sae
 
     def save_pretrained(
         self,
@@ -416,4 +529,13 @@ def save_pretrained(
         elif ckpt_path.endswith(".pt"):
             torch.save({"sae": self.state_dict(), "version": version("lm-saes")}, ckpt_path)
         else:
-            raise ValueError(f"Invalid checkpoint path {ckpt_path}. Currently only supports .safetensors and .pt formats.")
\ No newline at end of file
+            raise ValueError(f"Invalid checkpoint path {ckpt_path}. Currently only supports .safetensors and .pt formats.")
+
+    @property
+    def decoder_norm(self):
+        return torch.norm(self.decoder, p=2, dim=1).mean()
+
+    @property
+    def encoder_norm(self):
+        return torch.norm(self.encoder, p=2, dim=0).mean()
+
diff --git a/src/lm_saes/sae_training.py b/src/lm_saes/sae_training.py
index 682b5e9..28ec0e4 100644
--- a/src/lm_saes/sae_training.py
+++ b/src/lm_saes/sae_training.py
@@ -73,6 +73,7 @@ def train_sae(
     pbar = tqdm(total=total_training_tokens, desc="Training SAE", smoothing=0.01)
     while n_training_tokens < total_training_tokens:
         sae.train()
+        sae.update_l1_coefficient(n_training_steps)
         # Get the next batch of activations
         batch = activation_store.next(batch_size=cfg.train_batch_size)
         assert batch is not None, "Activation store is empty"
@@ -108,10 +109,15 @@ def train_sae(
         if cfg.finetuning:
             loss = loss_data['l_rec'].mean()
         loss.backward()
-        sae.remove_gradient_parallel_to_decoder_directions()
+
+        if cfg.clip_grad_norm > 0:
+            torch.nn.utils.clip_grad_norm_(sae.parameters(), cfg.clip_grad_norm)
+        if cfg.remove_gradient_parallel_to_decoder_directions:
+            sae.remove_gradient_parallel_to_decoder_directions()
         optimizer.step()
 
-        sae.set_decoder_norm_to_unit_norm()
+        if not cfg.sae.sparsity_include_decoder_norm:
+            sae.set_decoder_norm_to_fixed_norm(1)
         with torch.no_grad():
             act_freq_scores += (aux_data["feature_acts"].abs() > 0).float().sum(0)
             n_frac_active_tokens += activation_in.size(0)
@@ -143,7 +149,7 @@ def train_sae(
                     act_freq_scores = torch.zeros(cfg.sae.d_sae, device=cfg.sae.device)
                     n_frac_active_tokens = torch.tensor([0], device=cfg.sae.device, dtype=torch.int)
 
-            if ((n_training_steps + 1) % cfg.log_frequency == 0):
+            if (n_training_steps + 1) % cfg.log_frequency == 0:
                 # metrics for currents acts
                 l0 = (aux_data["feature_acts"] > 0).float().sum(-1).mean()
                 l_rec = loss_data["l_rec"].mean()
@@ -198,7 +204,13 @@ def train_sae(
                        # "metrics/mean_thomson_potential": mean_thomson_potential.item(),
                        "metrics/l2_norm_error": l2_norm_error.item(),
                        "metrics/l2_norm_error_ratio": l2_norm_error_ratio.item(),
+                       # norm
+                       "metrics/decoder_norm": sae.decoder_norm.item(),
+                       "metrics/encoder_norm": sae.encoder_norm.item(),
+                       "metrics/decoder_bias_mean": sae.decoder_bias.mean().item() if sae.cfg.use_decoder_bias else 0,
+                       "metrics/encoder_bias_mean": sae.encoder_bias.mean().item(),
                        # sparsity
+                       "sparsity/l1_coefficient": sae.current_l1_coefficient,
                        "sparsity/mean_passes_since_fired": n_forward_passes_since_fired.mean().item(),
                        "sparsity/dead_features": ghost_grad_neuron_mask.sum().item(),
                        "sparsity/useful_features": sae.decoder.norm(p=2, dim=1).gt(0.99).sum().item(),
@@ -230,7 +242,8 @@ def train_sae(
                    path = os.path.join(
                        cfg.exp_result_dir, cfg.exp_name, "checkpoints", f"{n_training_steps}.safetensors"
                    )
-                   sae.set_decoder_norm_to_unit_norm()
+                   if not cfg.sae.sparsity_include_decoder_norm:
+                       sae.set_decoder_norm_to_fixed_norm(1)
                    sae.save_pretrained(path)
                    checkpoint_thresholds.pop(0)
 
@@ -253,7 +266,10 @@ def train_sae(
     path = os.path.join(
         cfg.exp_result_dir, cfg.exp_name, "checkpoints", "final.safetensors"
    )
-    sae.set_decoder_norm_to_unit_norm()
+    if cfg.sae.sparsity_include_decoder_norm:
+        sae.transform_to_unit_decoder_norm()
+    else:
+        sae.set_decoder_norm_to_fixed_norm(1)
     sae.save_pretrained(path)
 
 
 @torch.no_grad()
@@ -325,4 +341,4 @@ def prune_sae(
     )
     sae.save_pretrained(path)
 
-    return sae
\ No newline at end of file
+    return sae
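For reference, the rescaling that `transform_to_unit_decoder_norm` performs at the final checkpoint (instead of re-normalizing the decoder every step, which the old `set_decoder_norm_to_unit_norm` call did) can be illustrated with a minimal standalone sketch. The names `W_enc`, `b_enc`, `W_dec`, the plain ReLU encoder, and the omission of the decoder bias and activation normalization below are illustrative assumptions, not the lm_saes implementation:

```python
# Minimal sketch (assumed names, not lm_saes code): an SAE trained with the
# decoder-norm-weighted sparsity penalty ends up with arbitrary decoder row
# norms, but it can be rebalanced to unit decoder norm without changing its
# reconstructions, mirroring SparseAutoEncoder.transform_to_unit_decoder_norm.
import torch

torch.manual_seed(0)
d_model, d_sae = 16, 64

W_enc = torch.randn(d_model, d_sae)  # encoder weight, shape (d_model, d_sae)
b_enc = torch.randn(d_sae)           # encoder bias
W_dec = torch.randn(d_sae, d_model)  # decoder weight, shape (d_sae, d_model)


def sae_forward(x, W_enc, b_enc, W_dec):
    feature_acts = torch.relu(x @ W_enc + b_enc)  # feature activations
    return feature_acts @ W_dec                   # reconstruction (decoder bias omitted)


x = torch.randn(8, d_model)
out_before = sae_forward(x, W_enc, b_enc, W_dec)

# Absorb each decoder row's norm into the matching encoder column and bias entry,
# then normalize the decoder rows (ReLU commutes with the positive scaling).
decoder_norm = W_dec.norm(p=2, dim=1)        # (d_sae,)
W_enc_scaled = W_enc * decoder_norm          # scales encoder columns
b_enc_scaled = b_enc * decoder_norm
W_dec_unit = W_dec / decoder_norm[:, None]   # unit-norm decoder rows

out_after = sae_forward(x, W_enc_scaled, b_enc_scaled, W_dec_unit)
print(torch.allclose(out_before, out_after, rtol=1e-4, atol=1e-5))  # True: output unchanged
```

Because the decoder row norms are absorbed into the encoder weights and bias, the rebalanced SAE produces identical reconstructions while its decoder rows are exactly unit-norm, matching the convention that the rest of the training code (e.g. the `sparsity/useful_features` metric) assumes.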