From b8aba9e360956a80a8d71b1f6089bdd79fb8fa0d Mon Sep 17 00:00:00 2001 From: Hzfinfdu Date: Mon, 1 Jul 2024 19:13:19 +0800 Subject: [PATCH 1/7] feature(MEGA UPDATE): implement options in Anthropic April Update e.g. combine decoder bias norm with L1 --- examples/programmatic/train.py | 91 ++++++++++++++++++++------- src/lm_saes/config.py | 32 ++++++++-- src/lm_saes/runner.py | 31 ++++++--- src/lm_saes/sae.py | 112 +++++++++++++++++++++++++++------ src/lm_saes/sae_training.py | 100 +++++++++++++++++++++++++++-- 5 files changed, 303 insertions(+), 63 deletions(-) diff --git a/examples/programmatic/train.py b/examples/programmatic/train.py index dd7e695..2cf7dc8 100644 --- a/examples/programmatic/train.py +++ b/examples/programmatic/train.py @@ -2,42 +2,87 @@ from lm_saes.config import LanguageModelSAETrainingConfig from lm_saes.runner import language_model_sae_runner +# import argparse + + +# parser = argparse.ArgumentParser() +# parser.add_argument("--lr", type=float, default=4e-4) +# parser.add_argument("--l1_coef", type=float, default=8e-5) +# parser.add_argument("--sparsity_include_decoder_norm", action="store_true") +# parser.add_argument("--remove_gradient_parallel_to_decoder_directions", action="store_true") +# parser.add_argument("--use_decoder_bias", action="store_true") +# parser.add_argument("--init_encoder_with_decoder_transpose", action="store_true") +# args = parser.parse_args() + +# lr = args.lr +# l1_coefficient = args.l1_coef +# sparsity_include_decoder_norm = args.sparsity_include_decoder_norm +# remove_gradient_parallel_to_decoder_directions = args.remove_gradient_parallel_to_decoder_directions +# use_decoder_bias = args.use_decoder_bias +# init_encoder_with_decoder_transpose = args.init_encoder_with_decoder_transpose + cfg = LanguageModelSAETrainingConfig.from_flattened(dict( # LanguageModelConfig - model_name = "gpt2", # The model name or path for the pre-trained model. - d_model = 768, # The hidden size of the model. + model_name = "meta-llama/Meta-Llama-3-8B-Instruct", # The model name or path for the pre-trained model. + model_from_pretrained_path="/remote-home/share/models/llama3_hf/Meta-Llama-3-8B-Instruct", # The path to load the pre-trained model. + d_model = 4096, # The hidden size of the model. # TextDatasetConfig - dataset_path = "openwebtext", # The corpus name or path. Each of a data record should contain (and may only contain) a "text" field. + dataset_path = [ + "/remote-home/share/personal/zfhe/projects/SAE_data_at_scale/data/Pretrain_RedPajama_arxiv_500k", + "/remote-home/share/personal/zfhe/projects/SAE_data_at_scale/data/Pretrain_RedPajama_c4_500k", + "/remote-home/share/personal/zfhe/projects/SAE_data_at_scale/data/Pretrain_RedPajama_stack_500k", + "/remote-home/share/personal/zfhe/projects/SAE_data_at_scale/data/Pretrain_RedPajama_book_500k", + "/remote-home/share/personal/zfhe/projects/SAE_data_at_scale/data/Pretrain_RedPajama_pile_500k", + "/remote-home/share/personal/zfhe/projects/SAE_data_at_scale/data/Pretrain_RedPajama_wiki_500k", + "/remote-home/share/personal/zfhe/projects/SAE_data_at_scale/data/SFT_WildChatClean", + ], # The corpus name or path. Each of a data record should contain (and may only contain) a "text" field. + sample_probs = [1, 8, 1, 2, 4, 4, 1], is_dataset_tokenized = False, # Whether the dataset is tokenized. is_dataset_on_disk = True, # Whether the dataset is on disk. If not on disk, `datasets.load_dataset`` will be used to load the dataset, and the train split will be used for training. 
- concat_tokens = False, # Whether to concatenate tokens into a single sequence. If False, only data record with length of non-padding tokens larger than `context_size` will be used. - context_size = 256, # The sequence length of the text dataset. - store_batch_size = 32, # The batch size for loading the corpus. + concat_tokens = [ + True, + True, + True, + True, + True, + True, + False, + ], # Whether to concatenate tokens into a single sequence. If False, only data record with length of non-padding tokens larger than `context_size` will be used. + context_size = 1024, # The sequence length of the text dataset. + store_batch_size = 20, # The batch size for loading the corpus. # ActivationStoreConfig - hook_points = ["blocks.3.hook_mlp_out"], # Hook points to store activations from, i.e. the layer output of which is used for training/evaluating the dictionary. Will run until the last hook point in the list, so make sure to order them correctly. + hook_points = ['blocks.16.hook_resid_pre'], # Hook points to store activations from, i.e. the layer output of which is used for training/evaluating the dictionary. Will run until the last hook point in the list, so make sure to order them correctly. use_cached_activations = False, # Whether to use cached activations. Caching activation is now not recommended, as it may consume extremely large disk space. (May be tens of TBs for corpus like `openwebtext`) n_tokens_in_buffer = 500_000, # The number of tokens to store in the activation buffer. The buffer is used to shuffle the activations before training the dictionary. # SAEConfig - hook_point_in = "blocks.3.hook_mlp_out", - hook_point_out = "blocks.3.hook_mlp_out", + hook_point_in = 'blocks.16.hook_resid_pre', + hook_point_out = 'blocks.16.hook_resid_pre', + use_decoder_bias = True, # Whether to use decoder bias. expansion_factor = 32, # The expansion factor of the dictionary. d_sae = expansion_factor * d_model. norm_activation = "token-wise", # The normalization method for the activations. Can be "token-wise", "batch-wise" or "none". - decoder_exactly_unit_norm = False, # Whether to enforce the decoder to have exactly unit norm. If False, the decoder will have less than or equal to unit norm. + decoder_exactly_fixed_norm = False, # Whether to enforce the decoder to have exactly unit norm. If False, the decoder will have less than or equal to unit norm. use_glu_encoder = False, # Whether to use the Gated Linear Unit (GLU) for the encoder. - l1_coefficient = 1.2e-4, # The L1 regularization coefficient for the feature activations. + l1_coefficient = 2e-4, # The L1 regularization coefficient for the feature activations. + l1_coefficient_warmup_steps = 10000, # The number of warm-up steps for the L1 regularization coefficient. lp = 1, # The p-norm to use for the L1 regularization. - use_ghost_grads = True, # Whether to use the ghost gradients for saving dead features. + use_ghost_grads = False, # Whether to use the ghost gradients for saving dead features. + init_decoder_norm = 'auto', + init_encoder_with_decoder_transpose = True, + apply_decoder_bias_to_pre_encoder = True, + sparsity_include_decoder_norm = True, # LanguageModelSAETrainingConfig - total_training_tokens = 1_600_000_000, # The total number of tokens to train the dictionary. - lr = 4e-4, # The learning rate for the dictionary training. - betas = (0, 0.9999), # The betas for the Adam optimizer. + total_training_tokens = 300_000_000, # The total number of tokens to train the dictionary. 
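+    # Note: with train_batch_size = 4096 (set below), 300M tokens correspond to roughly 73k optimizer steps.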
+ lr = 1e-4, # The learning rate for the dictionary training. + betas = (0.9, 0.9999), # The betas for the Adam optimizer. + lr_scheduler_name = "constantwithwarmup", # The learning rate scheduler name. Can be "constant", "constantwithwarmup", "linearwarmupdecay", "cosineannealing", "cosineannealingwarmup" or "exponentialwarmup". - lr_warm_up_steps = 5000, # The number of warm-up steps for the learning rate. - lr_cool_down_steps = 10000, # The number of cool-down steps for the learning rate. Currently only used for the "constantwithwarmup" scheduler. + lr_warm_up_steps = 2000, # The number of warm-up steps for the learning rate. + lr_cool_down_steps = 4000, # The number of cool-down steps for the learning rate. Currently only used for the "constantwithwarmup" scheduler. + clip_grad_norm = 0.0, # The maximum gradient norm for clipping. If 0.0, no gradient clipping will be performed. train_batch_size = 4096, # The batch size for training the dictionary, i.e. the number of token activations in a batch. feature_sampling_window = 1000, # The window size for sampling the feature activations. dead_feature_window = 5000, # The window size for detecting the dead features. @@ -45,18 +90,20 @@ eval_frequency = 1000, # The step frequency for evaluating the dictionary. log_frequency = 100, # The step frequency for logging the training information (to wandb). n_checkpoints = 10, # The number of checkpoints to save during the training. + remove_gradient_parallel_to_decoder_directions = False, + # WandbConfig log_to_wandb = True, # Whether to log the training information to wandb. - wandb_project= "gpt2-sae", # The wandb project name. - + wandb_project= "llama3-sae", # The wandb project name. + # RunnerConfig device = "cuda", # The device to place all torch tensors. seed = 42, # The random seed. - dtype = torch.float32, # The torch data type of non-integer tensors. + dtype = torch.bfloat16, # The torch data type of non-integer tensors. - exp_name = "L3M", # The experiment name. Would be used for creating exp folder (which may contain checkpoints and analysis results) and setting wandb run name. - exp_series = "default", + exp_name = f"llama-L16RPr", # The experiment name. Would be used for creating exp folder (which may contain checkpoints and analysis results) and setting wandb run name. + exp_series = "llama3-sae", exp_result_dir = "results" )) diff --git a/src/lm_saes/config.py b/src/lm_saes/config.py index 785e6c2..00cf4ec 100644 --- a/src/lm_saes/config.py +++ b/src/lm_saes/config.py @@ -163,21 +163,26 @@ class SAEConfig(BaseModelConfig): sae_pretrained_name_or_path: Optional[str] = None strict_loading: bool = True - use_decoder_bias: bool = False + use_decoder_bias: bool = True apply_decoder_bias_to_pre_encoder: bool = True # set to False when training transcoders - decoder_bias_init_method: str = "geometric_median" - expansion_factor: int = 32 + expansion_factor: int = 128 d_model: int = 768 d_sae: int = None # type: ignore """ The dimension of the SAE, i.e. the number of dictionary components (or features). 
If None, it will be set to d_model * expansion_factor """ - norm_activation: str = "token-wise" # none, token-wise, batch-wise - decoder_exactly_unit_norm: bool = True + norm_activation: str = "token-wise" # none, token-wise, batch-wise, dataset-wise + dataset_average_activation_norm: Dict[str, float] | None = None + decoder_exactly_fixed_norm: bool = False + sparsity_include_decoder_norm: bool = True # set to True: sparsity loss = sum(act * corresponding_decoder_norm), otherwise loss = sum(act). Incompatible with decoder_exactly_fixed_norm use_glu_encoder: bool = False + init_decoder_norm: float | str = 'auto' + init_encoder_norm: float | str | None = 'auto' + init_encoder_with_decoder_transpose: bool = True l1_coefficient: float = 0.00008 + l1_coefficient_warmup_steps: int = 0 lp: int = 1 - use_ghost_grads: bool = True + use_ghost_grads: bool = False def __post_init__(self): super().__post_init__() @@ -185,6 +190,19 @@ def __post_init__(self): self.hook_point_out = self.hook_point_in if self.d_sae is None: self.d_sae = self.d_model * self.expansion_factor + if self.norm_activation == "dataset-wise" and self.dataset_average_activation_norm is None: + print(f'dataset_average_activation_norm is None and norm_activation is "dataset-wise". Will be computed automatically from the dataset.') + if self.sparsity_include_decoder_norm and self.decoder_exactly_fixed_norm: + raise ValueError("sparsity_include_decoder_norm and decoder_exactly_fixed_norm are incompatible.") + if self.sparsity_include_decoder_norm and self.use_ghost_grads: + raise ValueError("sparsity_include_decoder_norm and use_ghost_grads are incompatible.") + if isinstance(self.init_decoder_norm, str) and self.init_decoder_norm != 'auto': + raise ValueError("init_decoder_norm must be a float or 'auto'.") + if isinstance(self.init_encoder_norm, str) and self.init_encoder_norm != 'auto': + raise ValueError("init_encoder_norm must be None, a float or 'auto'.") + if self.init_encoder_with_decoder_transpose and isinstance(self.init_encoder_norm, float): + raise ValueError("init_encoder_with_decoder_transpose and init_encoder_norm with float are incompatible.") + @staticmethod def from_pretrained(pretrained_name_or_path: str, strict_loading: bool = True, **kwargs): @@ -269,6 +287,8 @@ class LanguageModelSAETrainingConfig(LanguageModelSAERunnerConfig): lr_warm_up_steps: int = 5000 lr_cool_down_steps: int = 10000 train_batch_size: int = 4096 + clip_grad_norm: float = 0.0 + remove_gradient_parallel_to_decoder_directions: bool = False finetuning: bool = False diff --git a/src/lm_saes/runner.py b/src/lm_saes/runner.py index e5f0224..f993903 100644 --- a/src/lm_saes/runner.py +++ b/src/lm_saes/runner.py @@ -25,20 +25,13 @@ from lm_saes.sae import SparseAutoEncoder from lm_saes.activation.activation_dataset import make_activation_dataset from lm_saes.activation.activation_store import ActivationStore -from lm_saes.sae_training import prune_sae, train_sae +from lm_saes.sae_training import prune_sae, train_sae, init_sae_on_dataset from lm_saes.analysis.sample_feature_activations import sample_feature_activations from lm_saes.analysis.features_to_logits import features_to_logits -def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig): - cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name)) - cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name)) - sae = SparseAutoEncoder.from_config(cfg=cfg.sae) - - if cfg.finetuning: - # Fine-tune SAE with frozen encoder weights and bias - 
sae.train_finetune_for_suppression_parameters() +def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig): hf_model = AutoModelForCausalLM.from_pretrained( ( cfg.lm.model_name @@ -72,6 +65,26 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig): model.eval() activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store) + if ( + cfg.sae.norm_activation == "dataset-wise" and cfg.sae.dataset_average_activation_norm is None + or cfg.sae.init_decoder_norm == 'auto' + ): + assert not cfg.finetuning + sae = init_sae_on_dataset( + model, + activation_store, + cfg + ) + else: + sae = SparseAutoEncoder.from_config(cfg=cfg.sae) + + if cfg.finetuning: + # Fine-tune SAE with frozen encoder weights and bias + sae.train_finetune_for_suppression_parameters() + + cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name)) + cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name)) + if cfg.wandb.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0): wandb_config: dict = { **asdict(cfg), diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py index 6dda1f6..80dfaee 100644 --- a/src/lm_saes/sae.py +++ b/src/lm_saes/sae.py @@ -33,28 +33,26 @@ def __init__( super(SparseAutoEncoder, self).__init__() self.cfg = cfg + self.current_l1_coefficient = cfg.l1_coefficient self.encoder = torch.nn.Parameter(torch.empty((cfg.d_model, cfg.d_sae), dtype=cfg.dtype, device=cfg.device)) - torch.nn.init.kaiming_uniform_(self.encoder) if cfg.use_glu_encoder: self.encoder_glu = torch.nn.Parameter(torch.empty((cfg.d_model, cfg.d_sae), dtype=cfg.dtype, device=cfg.device)) - torch.nn.init.kaiming_uniform_(self.encoder_glu) self.encoder_bias_glu = torch.nn.Parameter(torch.empty((cfg.d_sae,), dtype=cfg.dtype, device=cfg.device)) - torch.nn.init.zeros_(self.encoder_bias_glu) self.feature_act_mask = torch.nn.Parameter(torch.ones((cfg.d_sae,), dtype=cfg.dtype, device=cfg.device)) self.feature_act_scale = torch.nn.Parameter(torch.ones((cfg.d_sae,), dtype=cfg.dtype, device=cfg.device)) self.decoder = torch.nn.Parameter(torch.empty((cfg.d_sae, cfg.d_model), dtype=cfg.dtype, device=cfg.device)) - torch.nn.init.kaiming_uniform_(self.decoder) - self.set_decoder_norm_to_unit_norm() + + if cfg.use_decoder_bias: self.decoder_bias = torch.nn.Parameter(torch.empty((cfg.d_model,), dtype=cfg.dtype, device=cfg.device)) self.encoder_bias = torch.nn.Parameter(torch.empty((cfg.d_sae,), dtype=cfg.dtype, device=cfg.device)) - torch.nn.init.zeros_(self.encoder_bias) + self.train_base_parameters() @@ -62,6 +60,29 @@ def __init__( self.hook_feature_acts = HookPoint() self.hook_reconstructed = HookPoint() + self.initialize_parameters() + + + def initialize_parameters(self): + torch.nn.init.kaiming_uniform_(self.encoder) + + if self.cfg.use_glu_encoder: + torch.nn.init.kaiming_uniform_(self.encoder_glu) + torch.nn.init.zeros_(self.encoder_bias_glu) + + torch.nn.init.kaiming_uniform_(self.decoder) + self.set_decoder_norm_to_fixed_norm(self.cfg.init_decoder_norm, force_exact=True) + + if self.cfg.use_decoder_bias: + torch.nn.init.zeros_(self.decoder_bias) + torch.nn.init.zeros_(self.encoder_bias) + + if self.cfg.init_encoder_with_decoder_transpose: + self.encoder.data = self.decoder.data.T.clone().contiguous() + else: + self.set_encoder_norm_to_fixed_norm(self.cfg.init_encoder_norm) + + def train_base_parameters(self): """Set the base parameters to be trained. 
""" @@ -96,7 +117,7 @@ def train_finetune_for_suppression_parameters(self): p.requires_grad_(True) - def compute_norm_factor(self, x: torch.Tensor) -> torch.Tensor: + def compute_norm_factor(self, x: torch.Tensor, hook_point: str) -> torch.Tensor: """Compute the normalization factor for the activation vectors. """ @@ -105,6 +126,9 @@ def compute_norm_factor(self, x: torch.Tensor) -> torch.Tensor: return math.sqrt(self.cfg.d_model) / torch.norm(x, 2, dim=-1, keepdim=True) elif self.cfg.norm_activation == "batch-wise": return math.sqrt(self.cfg.d_model) / torch.norm(x, 2, dim=-1, keepdim=True).mean(dim=-2, keepdim=True) + elif self.cfg.norm_activation == "dataset-wise": + assert self.cfg.dataset_average_activation_norm is not None, "dataset_average_activation_norm must be provided for dataset-wise normalization" + return math.sqrt(self.cfg.d_model) / self.cfg.dataset_average_activation_norm[hook_point] else: return torch.tensor(1.0, dtype=self.cfg.dtype, device=self.cfg.device) @@ -148,7 +172,7 @@ def encode( if self.cfg.use_decoder_bias and self.cfg.apply_decoder_bias_to_pre_encoder: x = x - self.decoder_bias - x = x * self.compute_norm_factor(x) + x = x * self.compute_norm_factor(x, hook_point='in') hidden_pre = einsum( x, @@ -165,7 +189,7 @@ def encode( hidden_pre_glu = torch.sigmoid(hidden_pre_glu) hidden_pre = hidden_pre * hidden_pre_glu - hidden_pre = hidden_pre / self.compute_norm_factor(label) + hidden_pre = hidden_pre / self.compute_norm_factor(label, hook_point='in') hidden_pre = self.hook_hidden_pre(hidden_pre) feature_acts = self.feature_act_mask * self.feature_act_scale * torch.clamp(hidden_pre, min=0.0) @@ -221,11 +245,11 @@ def compute_loss( if label is None: label = x - label_norm_factor = self.compute_norm_factor(label) + label_norm_factor = self.compute_norm_factor(label, hook_point='out') feature_acts, hidden_pre = self.encode(x, label, return_hidden_pre=True) - feature_acts_normed = feature_acts * label_norm_factor - hidden_pre_normed = hidden_pre * label_norm_factor + feature_acts_normed = feature_acts * label_norm_factor # (batch, d_sae) + # hidden_pre_normed = hidden_pre * label_norm_factor reconstructed = self.decode(feature_acts) reconstructed_normed = reconstructed * label_norm_factor @@ -236,7 +260,11 @@ def compute_loss( l_rec = (reconstructed_normed - label_normed).pow(2) / (label_normed - label_normed.mean(dim=0, keepdim=True)).pow(2).sum(dim=-1, keepdim=True).clamp(min=1e-8).sqrt() # l_l1: (batch,) - l_l1 = torch.norm(feature_acts_normed, p=self.cfg.lp, dim=-1) + if self.cfg.sparsity_include_decoder_norm: + l_l1 = torch.norm(feature_acts_normed * torch.norm(self.decoder, p=2, dim=1), p=self.cfg.lp, dim=-1) + else: + l_l1 = torch.norm(feature_acts_normed, p=self.cfg.lp, dim=-1) + l_ghost_resid = torch.tensor(0.0, dtype=self.cfg.dtype, device=self.cfg.device) @@ -268,7 +296,7 @@ def compute_loss( mse_rescaling_factor = (l_rec / (l_ghost_resid + 1e-6)).detach() l_ghost_resid = mse_rescaling_factor * l_ghost_resid - loss = l_rec.mean() + self.cfg.l1_coefficient * l_l1.mean() + l_ghost_resid.mean() + loss = l_rec.mean() + self.current_l1_coefficient * l_l1.mean() + l_ghost_resid.mean() if return_aux_data: aux_data = { @@ -296,15 +324,51 @@ def forward( reconstructed = self.decode(feature_acts) return reconstructed + + @torch.no_grad() + def update_l1_coefficient(self, training_step): + if self.cfg.l1_coefficient_warmup_steps <= 0: + return + self.current_l1_coefficient = min(1., training_step / self.cfg.l1_coefficient_warmup_steps) * self.cfg.l1_coefficient 
@torch.no_grad() - def set_decoder_norm_to_unit_norm(self): + def set_decoder_norm_to_fixed_norm(self, value: float = 1.0, force_exact: bool | None = None): decoder_norm = torch.norm(self.decoder, dim=1, keepdim=True) - if self.cfg.decoder_exactly_unit_norm: - self.decoder.data = self.decoder.data / decoder_norm + if force_exact is None: + force_exact = self.cfg.decoder_exactly_fixed_norm + if force_exact: + self.decoder.data = self.decoder.data * value / decoder_norm else: - # Set the norm of the decoder to not exceed 1 - self.decoder.data = self.decoder.data / torch.clamp(decoder_norm, min=1.0) + # Set the norm of the decoder to not exceed value + self.decoder.data = self.decoder.data * value / torch.clamp(decoder_norm, min=value) + + @torch.no_grad() + def set_encoder_norm_to_fixed_norm(self, value: float = 1.0): + if self.cfg.use_glu_encoder: + raise NotImplementedError("GLU encoder not supported") + if value is None: + print(f'Encoder norm is not set to a fixed value, using random initialization.') + return + encoder_norm = torch.norm(self.encoder, dim=0, keepdim=True) # [1, d_sae] + self.encoder.data = self.encoder.data * value / encoder_norm + + + @torch.no_grad() + def transform_to_unit_decoder_norm(self): + """ + If we include decoder norm in the sparsity loss, the final decoder norm is not guaranteed to be 1. + We make an equivalent transformation to the decoder to make it unit norm. + See https://transformer-circuits.pub/2024/april-update/index.html#training-saes + """ + assert self.cfg.sparsity_include_decoder_norm, "Decoder norm is not included in the sparsity loss" + if self.cfg.use_glu_encoder: + raise NotImplementedError("GLU encoder not supported") + + decoder_norm = torch.norm(self.decoder, p=2, dim=1) # (d_sae,) + self.encoder.data = self.encoder.data * decoder_norm + self.decoder.data = self.decoder.data / decoder_norm[:, None] + + self.encoder_bias.data = self.encoder_bias.data * decoder_norm @torch.no_grad() @@ -416,4 +480,12 @@ def save_pretrained( elif ckpt_path.endswith(".pt"): torch.save({"sae": self.state_dict(), "version": version("lm-saes")}, ckpt_path) else: - raise ValueError(f"Invalid checkpoint path {ckpt_path}. Currently only supports .safetensors and .pt formats.") \ No newline at end of file + raise ValueError(f"Invalid checkpoint path {ckpt_path}. 
Currently only supports .safetensors and .pt formats.") + + @property + def decoder_norm(self): + return torch.norm(self.decoder, p=2, dim=1).mean() + + @property + def encoder_norm(self): + return torch.norm(self.encoder, p=2, dim=0).mean() diff --git a/src/lm_saes/sae_training.py b/src/lm_saes/sae_training.py index 682b5e9..40139c9 100644 --- a/src/lm_saes/sae_training.py +++ b/src/lm_saes/sae_training.py @@ -73,6 +73,7 @@ def train_sae( pbar = tqdm(total=total_training_tokens, desc="Training SAE", smoothing=0.01) while n_training_tokens < total_training_tokens: sae.train() + sae.update_l1_coefficient(n_training_steps) # Get the next batch of activations batch = activation_store.next(batch_size=cfg.train_batch_size) assert batch is not None, "Activation store is empty" @@ -108,10 +109,15 @@ def train_sae( if cfg.finetuning: loss = loss_data['l_rec'].mean() loss.backward() - sae.remove_gradient_parallel_to_decoder_directions() + + if cfg.clip_grad_norm > 0: + torch.nn.utils.clip_grad_norm_(sae.parameters(), cfg.clip_grad_norm) + if cfg.remove_gradient_parallel_to_decoder_directions: + sae.remove_gradient_parallel_to_decoder_directions() optimizer.step() - sae.set_decoder_norm_to_unit_norm() + if not cfg.sae.sparsity_include_decoder_norm: + sae.set_decoder_norm_to_fixed_norm(1) with torch.no_grad(): act_freq_scores += (aux_data["feature_acts"].abs() > 0).float().sum(0) n_frac_active_tokens += activation_in.size(0) @@ -143,7 +149,7 @@ def train_sae( act_freq_scores = torch.zeros(cfg.sae.d_sae, device=cfg.sae.device) n_frac_active_tokens = torch.tensor([0], device=cfg.sae.device, dtype=torch.int) - if ((n_training_steps + 1) % cfg.log_frequency == 0): + if (n_training_steps + 1) % cfg.log_frequency == 0: # metrics for currents acts l0 = (aux_data["feature_acts"] > 0).float().sum(-1).mean() l_rec = loss_data["l_rec"].mean() @@ -198,7 +204,13 @@ def train_sae( # "metrics/mean_thomson_potential": mean_thomson_potential.item(), "metrics/l2_norm_error": l2_norm_error.item(), "metrics/l2_norm_error_ratio": l2_norm_error_ratio.item(), + # norm + "metrics/decoder_norm": sae.decoder_norm.item(), + "metrics/encoder_norm": sae.encoder_norm.item(), + "metrics/decoder_bias_mean": sae.decoder_bias.mean().item() if sae.cfg.use_decoder_bias else 0, + "metrics/enocder_bias_mean": sae.encoder_bias.mean().item(), # sparsity + "sparsity/l1_coefficient": sae.current_l1_coefficient, "sparsity/mean_passes_since_fired": n_forward_passes_since_fired.mean().item(), "sparsity/dead_features": ghost_grad_neuron_mask.sum().item(), "sparsity/useful_features": sae.decoder.norm(p=2, dim=1).gt(0.99).sum().item(), @@ -230,7 +242,8 @@ def train_sae( path = os.path.join( cfg.exp_result_dir, cfg.exp_name, "checkpoints", f"{n_training_steps}.safetensors" ) - sae.set_decoder_norm_to_unit_norm() + if not cfg.sae.sparsity_include_decoder_norm: + sae.set_decoder_norm_to_fixed_norm(1) sae.save_pretrained(path) checkpoint_thresholds.pop(0) @@ -253,7 +266,10 @@ def train_sae( path = os.path.join( cfg.exp_result_dir, cfg.exp_name, "checkpoints", "final.safetensors" ) - sae.set_decoder_norm_to_unit_norm() + if cfg.sae.sparsity_include_decoder_norm: + sae.transform_to_unit_decoder_norm() + else: + sae.set_decoder_norm_to_fixed_norm(1) sae.save_pretrained(path) @torch.no_grad() @@ -325,4 +341,76 @@ def prune_sae( ) sae.save_pretrained(path) - return sae \ No newline at end of file + return sae + +@torch.no_grad() +def init_sae_on_dataset( + model: HookedTransformer, + activation_store: ActivationStore, + cfg: LanguageModelSAETrainingConfig, 
+): + test_batch = activation_store.next(batch_size=cfg.train_batch_size * 8) # just random hard code xd + activation_in, activation_out = test_batch[cfg.sae.hook_point_in], test_batch[ + cfg.sae.hook_point_out] # [batch, d_model] + + if cfg.sae.norm_activation == "dataset-wise" and cfg.sae.dataset_average_activation_norm is None: + print(f'SAE: Computing average activation norm on the first {cfg.train_batch_size * 8} samples.') + + average_in_norm, average_out_norm = activation_in.norm(p=2, dim=1).mean().item(), activation_out.norm(p=2, + dim=1).mean().item() + + print(f'Average input activation norm: {average_in_norm}\nAverage output activation norm: {average_out_norm}') + cfg.sae.dataset_average_activation_norm = {'in': average_in_norm, 'out': average_out_norm} + + + if cfg.sae.init_decoder_norm == 'auto': + initializing_transcoder = False + assert cfg.sae.sparsity_include_decoder_norm, 'Decoder norm must be included in sparsity loss' + print('SAE: Starting grid search for initial decoder norm.') + if not cfg.sae.init_encoder_with_decoder_transpose: + assert cfg.sae.hook_point_in != cfg.sae.hook_point_out, 'If not training transcoder, it is recommended to init encoder with decoder transpose' + print('We are operating in 2-d grid search for encoder and decoder respectively.') + initializing_transcoder = True + + if not initializing_transcoder: + losses = {} + for norm in torch.linspace(0.1, 1, 10).numpy().tolist(): + cfg.sae.init_decoder_norm = norm + sae = SparseAutoEncoder.from_config(cfg=cfg.sae) + mse = sae.compute_loss(x=activation_in, label=activation_out)[1][0]['l_rec'].mean().item() + losses[cfg.sae.init_decoder_norm] = mse + best_norm = min(losses, key=losses.get) + losses = {} + for norm in torch.linspace(-0.09, 0.1, 20).numpy().tolist(): + cfg.sae.init_decoder_norm = best_norm + norm + sae = SparseAutoEncoder.from_config(cfg=cfg.sae) + mse = sae.compute_loss(x=activation_in, label=activation_out)[1][0]['l_rec'].mean().item() + losses[cfg.sae.init_decoder_norm] = mse + best_norm = min(losses, key=losses.get) + print(f'The best (i.e. lowest MSE) initialized norm is {best_norm}') + cfg.sae.init_decoder_norm = best_norm + else: + losses = {} + for dec_norm in torch.linspace(0.1, 1, 10).numpy().tolist(): + for enc_norm in torch.linspace(0.1, 1, 10).numpy().tolist(): + cfg.sae.init_decoder_norm = dec_norm + cfg.sae.init_encoder_norm = enc_norm + sae = SparseAutoEncoder.from_config(cfg=cfg.sae) + mse = sae.compute_loss(x=activation_in, label=activation_out)[1][0]['l_rec'].mean().item() + losses[(cfg.sae.init_decoder_norm, cfg.sae.init_encoder_norm)] = mse + best_norms = min(losses, key=losses.get) + losses = {} + for dec_norm in torch.linspace(-0.09, 0.1, 20).numpy().tolist(): + for enc_norm in torch.linspace(-0.09, 0.1, 20).numpy().tolist(): + cfg.sae.init_decoder_norm = best_norms[0] + dec_norm + cfg.sae.init_encoder_norm = best_norms[1] + enc_norm + sae = SparseAutoEncoder.from_config(cfg=cfg.sae) + mse = sae.compute_loss(x=activation_in, label=activation_out)[1][0]['l_rec'].mean().item() + losses[(cfg.sae.init_decoder_norm, cfg.sae.init_encoder_norm)] = mse + best_norms = min(losses, key=losses.get) + print(f'The best (i.e. 
lowest MSE) initialized norms are (decoder, encoder): {best_norms}') + cfg.sae.init_decoder_norm = best_norms[0] + cfg.sae.init_encoder_norm = best_norms[1] + + return SparseAutoEncoder.from_config(cfg=cfg.sae) + From 7f4091cb8655abd9e135ffbfdeed224c45fe815b Mon Sep 17 00:00:00 2001 From: Hzfinfdu Date: Mon, 1 Jul 2024 20:13:47 +0800 Subject: [PATCH 2/7] fix(sae): merge a standalone SAE init func with static method --- examples/programmatic/train.py | 64 +++++++--------------------- src/lm_saes/config.py | 8 +--- src/lm_saes/runner.py | 11 +++-- src/lm_saes/sae.py | 77 +++++++++++++++++++++++++++++++++- src/lm_saes/sae_training.py | 72 ------------------------------- 5 files changed, 97 insertions(+), 135 deletions(-) diff --git a/examples/programmatic/train.py b/examples/programmatic/train.py index 2cf7dc8..8b3bb91 100644 --- a/examples/programmatic/train.py +++ b/examples/programmatic/train.py @@ -2,66 +2,30 @@ from lm_saes.config import LanguageModelSAETrainingConfig from lm_saes.runner import language_model_sae_runner -# import argparse - - -# parser = argparse.ArgumentParser() -# parser.add_argument("--lr", type=float, default=4e-4) -# parser.add_argument("--l1_coef", type=float, default=8e-5) -# parser.add_argument("--sparsity_include_decoder_norm", action="store_true") -# parser.add_argument("--remove_gradient_parallel_to_decoder_directions", action="store_true") -# parser.add_argument("--use_decoder_bias", action="store_true") -# parser.add_argument("--init_encoder_with_decoder_transpose", action="store_true") -# args = parser.parse_args() - -# lr = args.lr -# l1_coefficient = args.l1_coef -# sparsity_include_decoder_norm = args.sparsity_include_decoder_norm -# remove_gradient_parallel_to_decoder_directions = args.remove_gradient_parallel_to_decoder_directions -# use_decoder_bias = args.use_decoder_bias -# init_encoder_with_decoder_transpose = args.init_encoder_with_decoder_transpose cfg = LanguageModelSAETrainingConfig.from_flattened(dict( # LanguageModelConfig - model_name = "meta-llama/Meta-Llama-3-8B-Instruct", # The model name or path for the pre-trained model. - model_from_pretrained_path="/remote-home/share/models/llama3_hf/Meta-Llama-3-8B-Instruct", # The path to load the pre-trained model. - d_model = 4096, # The hidden size of the model. + model_name = "gpt2", # The model name or path for the pre-trained model. + d_model = 768, # The hidden size of the model. # TextDatasetConfig - dataset_path = [ - "/remote-home/share/personal/zfhe/projects/SAE_data_at_scale/data/Pretrain_RedPajama_arxiv_500k", - "/remote-home/share/personal/zfhe/projects/SAE_data_at_scale/data/Pretrain_RedPajama_c4_500k", - "/remote-home/share/personal/zfhe/projects/SAE_data_at_scale/data/Pretrain_RedPajama_stack_500k", - "/remote-home/share/personal/zfhe/projects/SAE_data_at_scale/data/Pretrain_RedPajama_book_500k", - "/remote-home/share/personal/zfhe/projects/SAE_data_at_scale/data/Pretrain_RedPajama_pile_500k", - "/remote-home/share/personal/zfhe/projects/SAE_data_at_scale/data/Pretrain_RedPajama_wiki_500k", - "/remote-home/share/personal/zfhe/projects/SAE_data_at_scale/data/SFT_WildChatClean", - ], # The corpus name or path. Each of a data record should contain (and may only contain) a "text" field. - sample_probs = [1, 8, 1, 2, 4, 4, 1], + dataset_path = 'Skylion007/OpenWebText', # The corpus name or path. Each of a data record should contain (and may only contain) a "text" field. is_dataset_tokenized = False, # Whether the dataset is tokenized. 
is_dataset_on_disk = True, # Whether the dataset is on disk. If not on disk, `datasets.load_dataset`` will be used to load the dataset, and the train split will be used for training. - concat_tokens = [ - True, - True, - True, - True, - True, - True, - False, - ], # Whether to concatenate tokens into a single sequence. If False, only data record with length of non-padding tokens larger than `context_size` will be used. + concat_tokens = True, # Whether to concatenate tokens into a single sequence. If False, only data record with length of non-padding tokens larger than `context_size` will be used. context_size = 1024, # The sequence length of the text dataset. store_batch_size = 20, # The batch size for loading the corpus. # ActivationStoreConfig - hook_points = ['blocks.16.hook_resid_pre'], # Hook points to store activations from, i.e. the layer output of which is used for training/evaluating the dictionary. Will run until the last hook point in the list, so make sure to order them correctly. + hook_points = ['blocks.8.hook_resid_pre'], # Hook points to store activations from, i.e. the layer output of which is used for training/evaluating the dictionary. Will run until the last hook point in the list, so make sure to order them correctly. use_cached_activations = False, # Whether to use cached activations. Caching activation is now not recommended, as it may consume extremely large disk space. (May be tens of TBs for corpus like `openwebtext`) n_tokens_in_buffer = 500_000, # The number of tokens to store in the activation buffer. The buffer is used to shuffle the activations before training the dictionary. # SAEConfig - hook_point_in = 'blocks.16.hook_resid_pre', - hook_point_out = 'blocks.16.hook_resid_pre', + hook_point_in = 'blocks.8.hook_resid_pre', + hook_point_out = 'blocks.8.hook_resid_pre', use_decoder_bias = True, # Whether to use decoder bias. - expansion_factor = 32, # The expansion factor of the dictionary. d_sae = expansion_factor * d_model. + expansion_factor = 128, # The expansion factor of the dictionary. d_sae = expansion_factor * d_model. norm_activation = "token-wise", # The normalization method for the activations. Can be "token-wise", "batch-wise" or "none". decoder_exactly_fixed_norm = False, # Whether to enforce the decoder to have exactly unit norm. If False, the decoder will have less than or equal to unit norm. use_glu_encoder = False, # Whether to use the Gated Linear Unit (GLU) for the encoder. @@ -69,13 +33,13 @@ l1_coefficient_warmup_steps = 10000, # The number of warm-up steps for the L1 regularization coefficient. lp = 1, # The p-norm to use for the L1 regularization. use_ghost_grads = False, # Whether to use the ghost gradients for saving dead features. - init_decoder_norm = 'auto', + init_decoder_norm = None, # The initial norm of the decoder. If None, the decoder will be initialized automatically with the lowest MSE. init_encoder_with_decoder_transpose = True, apply_decoder_bias_to_pre_encoder = True, sparsity_include_decoder_norm = True, # LanguageModelSAETrainingConfig - total_training_tokens = 300_000_000, # The total number of tokens to train the dictionary. + total_training_tokens = 100_000_000, # The total number of tokens to train the dictionary. lr = 1e-4, # The learning rate for the dictionary training. betas = (0.9, 0.9999), # The betas for the Adam optimizer. @@ -95,15 +59,15 @@ # WandbConfig log_to_wandb = True, # Whether to log the training information to wandb. - wandb_project= "llama3-sae", # The wandb project name. 
+ wandb_project= "test", # The wandb project name. # RunnerConfig device = "cuda", # The device to place all torch tensors. seed = 42, # The random seed. - dtype = torch.bfloat16, # The torch data type of non-integer tensors. + dtype = torch.float32, # The torch data type of non-integer tensors. - exp_name = f"llama-L16RPr", # The experiment name. Would be used for creating exp folder (which may contain checkpoints and analysis results) and setting wandb run name. - exp_series = "llama3-sae", + exp_name = f"test", # The experiment name. Would be used for creating exp folder (which may contain checkpoints and analysis results) and setting wandb run name. + exp_series = "test", exp_result_dir = "results" )) diff --git a/src/lm_saes/config.py b/src/lm_saes/config.py index 00cf4ec..19935ab 100644 --- a/src/lm_saes/config.py +++ b/src/lm_saes/config.py @@ -174,8 +174,8 @@ class SAEConfig(BaseModelConfig): decoder_exactly_fixed_norm: bool = False sparsity_include_decoder_norm: bool = True # set to True: sparsity loss = sum(act * corresponding_decoder_norm), otherwise loss = sum(act). Incompatible with decoder_exactly_fixed_norm use_glu_encoder: bool = False - init_decoder_norm: float | str = 'auto' - init_encoder_norm: float | str | None = 'auto' + init_decoder_norm: float | None = None + init_encoder_norm: float | None = None init_encoder_with_decoder_transpose: bool = True l1_coefficient: float = 0.00008 @@ -196,10 +196,6 @@ def __post_init__(self): raise ValueError("sparsity_include_decoder_norm and decoder_exactly_fixed_norm are incompatible.") if self.sparsity_include_decoder_norm and self.use_ghost_grads: raise ValueError("sparsity_include_decoder_norm and use_ghost_grads are incompatible.") - if isinstance(self.init_decoder_norm, str) and self.init_decoder_norm != 'auto': - raise ValueError("init_decoder_norm must be a float or 'auto'.") - if isinstance(self.init_encoder_norm, str) and self.init_encoder_norm != 'auto': - raise ValueError("init_encoder_norm must be None, a float or 'auto'.") if self.init_encoder_with_decoder_transpose and isinstance(self.init_encoder_norm, float): raise ValueError("init_encoder_with_decoder_transpose and init_encoder_norm with float are incompatible.") diff --git a/src/lm_saes/runner.py b/src/lm_saes/runner.py index f993903..6c2d1a7 100644 --- a/src/lm_saes/runner.py +++ b/src/lm_saes/runner.py @@ -25,7 +25,7 @@ from lm_saes.sae import SparseAutoEncoder from lm_saes.activation.activation_dataset import make_activation_dataset from lm_saes.activation.activation_store import ActivationStore -from lm_saes.sae_training import prune_sae, train_sae, init_sae_on_dataset +from lm_saes.sae_training import prune_sae, train_sae from lm_saes.analysis.sample_feature_activations import sample_feature_activations from lm_saes.analysis.features_to_logits import features_to_logits @@ -67,13 +67,12 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig): if ( cfg.sae.norm_activation == "dataset-wise" and cfg.sae.dataset_average_activation_norm is None - or cfg.sae.init_decoder_norm == 'auto' + or cfg.sae.init_decoder_norm is None ): assert not cfg.finetuning - sae = init_sae_on_dataset( - model, - activation_store, - cfg + sae = SparseAutoEncoder.from_initialization_searching( + activation_store=activation_store, + cfg=cfg, ) else: sae = SparseAutoEncoder.from_config(cfg=cfg.sae) diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py index 80dfaee..adb2f6c 100644 --- a/src/lm_saes/sae.py +++ b/src/lm_saes/sae.py @@ -9,7 +9,8 @@ import safetensors.torch as safe 
-from lm_saes.config import SAEConfig +from lm_saes.config import SAEConfig, LanguageModelSAETrainingConfig +from lm_saes.activation.activation_store import ActivationStore from lm_saes.utils.huggingface import parse_pretrained_name_or_path class SparseAutoEncoder(HookedRootModule): @@ -333,6 +334,8 @@ def update_l1_coefficient(self, training_step): @torch.no_grad() def set_decoder_norm_to_fixed_norm(self, value: float = 1.0, force_exact: bool | None = None): + if value is None: + return decoder_norm = torch.norm(self.decoder, dim=1, keepdim=True) if force_exact is None: force_exact = self.cfg.decoder_exactly_fixed_norm @@ -462,6 +465,77 @@ def from_pretrained( cfg = SAEConfig.from_pretrained(pretrained_name_or_path, strict_loading=strict_loading, **kwargs) return SparseAutoEncoder.from_config(cfg) + + @torch.no_grad() + @staticmethod + def from_initialization_searching( + activation_store: ActivationStore, + cfg: LanguageModelSAETrainingConfig, + ): + test_batch = activation_store.next(batch_size=cfg.train_batch_size * 8) # just random hard code xd + activation_in, activation_out = test_batch[cfg.sae.hook_point_in], test_batch[ + cfg.sae.hook_point_out] # [batch, d_model] + + if cfg.sae.norm_activation == "dataset-wise" and cfg.sae.dataset_average_activation_norm is None: + print(f'SAE: Computing average activation norm on the first {cfg.train_batch_size * 8} samples.') + + average_in_norm, average_out_norm = activation_in.norm(p=2, dim=1).mean().item(), activation_out.norm(p=2, + dim=1).mean().item() + + print( + f'Average input activation norm: {average_in_norm}\nAverage output activation norm: {average_out_norm}') + cfg.sae.dataset_average_activation_norm = {'in': average_in_norm, 'out': average_out_norm} + + if cfg.sae.init_decoder_norm is None: + initializing_transcoder = False + assert cfg.sae.sparsity_include_decoder_norm, 'Decoder norm must be included in sparsity loss' + print('SAE: Starting grid search for initial decoder norm.') + if not cfg.sae.init_encoder_with_decoder_transpose: + assert cfg.sae.hook_point_in != cfg.sae.hook_point_out, 'If not training transcoder, it is recommended to init encoder with decoder transpose' + print('We are operating in 2-d grid search for encoder and decoder respectively.') + initializing_transcoder = True + + if not initializing_transcoder: + losses = {} + for norm in torch.linspace(0.1, 1, 10).numpy().tolist(): + cfg.sae.init_decoder_norm = norm + sae = SparseAutoEncoder.from_config(cfg=cfg.sae) + mse = sae.compute_loss(x=activation_in, label=activation_out)[1][0]['l_rec'].mean().item() + losses[cfg.sae.init_decoder_norm] = mse + best_norm = min(losses, key=losses.get) + losses = {} + for norm in torch.linspace(-0.09, 0.1, 20).numpy().tolist(): + cfg.sae.init_decoder_norm = best_norm + norm + sae = SparseAutoEncoder.from_config(cfg=cfg.sae) + mse = sae.compute_loss(x=activation_in, label=activation_out)[1][0]['l_rec'].mean().item() + losses[cfg.sae.init_decoder_norm] = mse + best_norm = min(losses, key=losses.get) + print(f'The best (i.e. 
lowest MSE) initialized norm is {best_norm}') + cfg.sae.init_decoder_norm = best_norm + else: + losses = {} + for dec_norm in torch.linspace(0.1, 1, 10).numpy().tolist(): + for enc_norm in torch.linspace(0.1, 1, 10).numpy().tolist(): + cfg.sae.init_decoder_norm = dec_norm + cfg.sae.init_encoder_norm = enc_norm + sae = SparseAutoEncoder.from_config(cfg=cfg.sae) + mse = sae.compute_loss(x=activation_in, label=activation_out)[1][0]['l_rec'].mean().item() + losses[(cfg.sae.init_decoder_norm, cfg.sae.init_encoder_norm)] = mse + best_norms = min(losses, key=losses.get) + losses = {} + for dec_norm in torch.linspace(-0.09, 0.1, 20).numpy().tolist(): + for enc_norm in torch.linspace(-0.09, 0.1, 20).numpy().tolist(): + cfg.sae.init_decoder_norm = best_norms[0] + dec_norm + cfg.sae.init_encoder_norm = best_norms[1] + enc_norm + sae = SparseAutoEncoder.from_config(cfg=cfg.sae) + mse = sae.compute_loss(x=activation_in, label=activation_out)[1][0]['l_rec'].mean().item() + losses[(cfg.sae.init_decoder_norm, cfg.sae.init_encoder_norm)] = mse + best_norms = min(losses, key=losses.get) + print(f'The best (i.e. lowest MSE) initialized norms are (decoder, encoder): {best_norms}') + cfg.sae.init_decoder_norm = best_norms[0] + cfg.sae.init_encoder_norm = best_norms[1] + + return SparseAutoEncoder.from_config(cfg=cfg.sae) def save_pretrained( self, @@ -489,3 +563,4 @@ def decoder_norm(self): @property def encoder_norm(self): return torch.norm(self.encoder, p=2, dim=0).mean() + diff --git a/src/lm_saes/sae_training.py b/src/lm_saes/sae_training.py index 40139c9..28ec0e4 100644 --- a/src/lm_saes/sae_training.py +++ b/src/lm_saes/sae_training.py @@ -342,75 +342,3 @@ def prune_sae( sae.save_pretrained(path) return sae - -@torch.no_grad() -def init_sae_on_dataset( - model: HookedTransformer, - activation_store: ActivationStore, - cfg: LanguageModelSAETrainingConfig, -): - test_batch = activation_store.next(batch_size=cfg.train_batch_size * 8) # just random hard code xd - activation_in, activation_out = test_batch[cfg.sae.hook_point_in], test_batch[ - cfg.sae.hook_point_out] # [batch, d_model] - - if cfg.sae.norm_activation == "dataset-wise" and cfg.sae.dataset_average_activation_norm is None: - print(f'SAE: Computing average activation norm on the first {cfg.train_batch_size * 8} samples.') - - average_in_norm, average_out_norm = activation_in.norm(p=2, dim=1).mean().item(), activation_out.norm(p=2, - dim=1).mean().item() - - print(f'Average input activation norm: {average_in_norm}\nAverage output activation norm: {average_out_norm}') - cfg.sae.dataset_average_activation_norm = {'in': average_in_norm, 'out': average_out_norm} - - - if cfg.sae.init_decoder_norm == 'auto': - initializing_transcoder = False - assert cfg.sae.sparsity_include_decoder_norm, 'Decoder norm must be included in sparsity loss' - print('SAE: Starting grid search for initial decoder norm.') - if not cfg.sae.init_encoder_with_decoder_transpose: - assert cfg.sae.hook_point_in != cfg.sae.hook_point_out, 'If not training transcoder, it is recommended to init encoder with decoder transpose' - print('We are operating in 2-d grid search for encoder and decoder respectively.') - initializing_transcoder = True - - if not initializing_transcoder: - losses = {} - for norm in torch.linspace(0.1, 1, 10).numpy().tolist(): - cfg.sae.init_decoder_norm = norm - sae = SparseAutoEncoder.from_config(cfg=cfg.sae) - mse = sae.compute_loss(x=activation_in, label=activation_out)[1][0]['l_rec'].mean().item() - losses[cfg.sae.init_decoder_norm] = mse - best_norm = 
min(losses, key=losses.get) - losses = {} - for norm in torch.linspace(-0.09, 0.1, 20).numpy().tolist(): - cfg.sae.init_decoder_norm = best_norm + norm - sae = SparseAutoEncoder.from_config(cfg=cfg.sae) - mse = sae.compute_loss(x=activation_in, label=activation_out)[1][0]['l_rec'].mean().item() - losses[cfg.sae.init_decoder_norm] = mse - best_norm = min(losses, key=losses.get) - print(f'The best (i.e. lowest MSE) initialized norm is {best_norm}') - cfg.sae.init_decoder_norm = best_norm - else: - losses = {} - for dec_norm in torch.linspace(0.1, 1, 10).numpy().tolist(): - for enc_norm in torch.linspace(0.1, 1, 10).numpy().tolist(): - cfg.sae.init_decoder_norm = dec_norm - cfg.sae.init_encoder_norm = enc_norm - sae = SparseAutoEncoder.from_config(cfg=cfg.sae) - mse = sae.compute_loss(x=activation_in, label=activation_out)[1][0]['l_rec'].mean().item() - losses[(cfg.sae.init_decoder_norm, cfg.sae.init_encoder_norm)] = mse - best_norms = min(losses, key=losses.get) - losses = {} - for dec_norm in torch.linspace(-0.09, 0.1, 20).numpy().tolist(): - for enc_norm in torch.linspace(-0.09, 0.1, 20).numpy().tolist(): - cfg.sae.init_decoder_norm = best_norms[0] + dec_norm - cfg.sae.init_encoder_norm = best_norms[1] + enc_norm - sae = SparseAutoEncoder.from_config(cfg=cfg.sae) - mse = sae.compute_loss(x=activation_in, label=activation_out)[1][0]['l_rec'].mean().item() - losses[(cfg.sae.init_decoder_norm, cfg.sae.init_encoder_norm)] = mse - best_norms = min(losses, key=losses.get) - print(f'The best (i.e. lowest MSE) initialized norms are (decoder, encoder): {best_norms}') - cfg.sae.init_decoder_norm = best_norms[0] - cfg.sae.init_encoder_norm = best_norms[1] - - return SparseAutoEncoder.from_config(cfg=cfg.sae) - From b0b13b0733596354e0f89d7a5a62c165f9e167a3 Mon Sep 17 00:00:00 2001 From: Hzfinfdu Date: Mon, 1 Jul 2024 20:19:08 +0800 Subject: [PATCH 3/7] fix(config): ignore mypy checking for norm init. --- src/lm_saes/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lm_saes/config.py b/src/lm_saes/config.py index 19935ab..438835a 100644 --- a/src/lm_saes/config.py +++ b/src/lm_saes/config.py @@ -174,8 +174,8 @@ class SAEConfig(BaseModelConfig): decoder_exactly_fixed_norm: bool = False sparsity_include_decoder_norm: bool = True # set to True: sparsity loss = sum(act * corresponding_decoder_norm), otherwise loss = sum(act). 
Incompatible with decoder_exactly_fixed_norm use_glu_encoder: bool = False - init_decoder_norm: float | None = None - init_encoder_norm: float | None = None + init_decoder_norm: float | None = None # type: ignore + init_encoder_norm: float | None = None # type: ignore init_encoder_with_decoder_transpose: bool = True l1_coefficient: float = 0.00008 From 25a46895b404fd9980631e39d50cb0367a99b323 Mon Sep 17 00:00:00 2001 From: Hzfinfdu Date: Mon, 1 Jul 2024 20:25:39 +0800 Subject: [PATCH 4/7] fix --- src/lm_saes/sae.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py index adb2f6c..90b8af9 100644 --- a/src/lm_saes/sae.py +++ b/src/lm_saes/sae.py @@ -333,7 +333,7 @@ def update_l1_coefficient(self, training_step): self.current_l1_coefficient = min(1., training_step / self.cfg.l1_coefficient_warmup_steps) * self.cfg.l1_coefficient @torch.no_grad() - def set_decoder_norm_to_fixed_norm(self, value: float = 1.0, force_exact: bool | None = None): + def set_decoder_norm_to_fixed_norm(self, value: float | None = 1.0, force_exact: bool | None = None): if value is None: return decoder_norm = torch.norm(self.decoder, dim=1, keepdim=True) @@ -346,7 +346,7 @@ def set_decoder_norm_to_fixed_norm(self, value: float = 1.0, force_exact: bool | self.decoder.data = self.decoder.data * value / torch.clamp(decoder_norm, min=value) @torch.no_grad() - def set_encoder_norm_to_fixed_norm(self, value: float = 1.0): + def set_encoder_norm_to_fixed_norm(self, value: float | None = 1.0): if self.cfg.use_glu_encoder: raise NotImplementedError("GLU encoder not supported") if value is None: From b84b799c506a3a448b2f10da2687e7fb841523c7 Mon Sep 17 00:00:00 2001 From: Hzfinfdu Date: Mon, 1 Jul 2024 20:28:18 +0800 Subject: [PATCH 5/7] fix --- src/lm_saes/sae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py index 90b8af9..58ebbf5 100644 --- a/src/lm_saes/sae.py +++ b/src/lm_saes/sae.py @@ -118,7 +118,7 @@ def train_finetune_for_suppression_parameters(self): p.requires_grad_(True) - def compute_norm_factor(self, x: torch.Tensor, hook_point: str) -> torch.Tensor: + def compute_norm_factor(self, x: torch.Tensor, hook_point: str) -> float | torch.Tensor: """Compute the normalization factor for the activation vectors. 
""" From 6864795d9866d55fcd68381fe337ac122f0e170b Mon Sep 17 00:00:00 2001 From: Hzfinfdu Date: Tue, 2 Jul 2024 14:03:42 +0800 Subject: [PATCH 6/7] fix(typing): fix mypy issues --- src/lm_saes/sae.py | 77 ++++++++++++++++------------------------------ 1 file changed, 26 insertions(+), 51 deletions(-) diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py index 58ebbf5..9db2b3c 100644 --- a/src/lm_saes/sae.py +++ b/src/lm_saes/sae.py @@ -1,6 +1,6 @@ from importlib.metadata import version import os -from typing import Dict, Literal, Union, overload +from typing import Dict, Literal, Union, overload, List import torch import math from einops import einsum @@ -473,8 +473,7 @@ def from_initialization_searching( cfg: LanguageModelSAETrainingConfig, ): test_batch = activation_store.next(batch_size=cfg.train_batch_size * 8) # just random hard code xd - activation_in, activation_out = test_batch[cfg.sae.hook_point_in], test_batch[ - cfg.sae.hook_point_out] # [batch, d_model] + activation_in, activation_out = test_batch[cfg.sae.hook_point_in], test_batch[cfg.sae.hook_point_out] # type: ignore if cfg.sae.norm_activation == "dataset-wise" and cfg.sae.dataset_average_activation_norm is None: print(f'SAE: Computing average activation norm on the first {cfg.train_batch_size * 8} samples.') @@ -487,55 +486,31 @@ def from_initialization_searching( cfg.sae.dataset_average_activation_norm = {'in': average_in_norm, 'out': average_out_norm} if cfg.sae.init_decoder_norm is None: - initializing_transcoder = False assert cfg.sae.sparsity_include_decoder_norm, 'Decoder norm must be included in sparsity loss' - print('SAE: Starting grid search for initial decoder norm.') - if not cfg.sae.init_encoder_with_decoder_transpose: - assert cfg.sae.hook_point_in != cfg.sae.hook_point_out, 'If not training transcoder, it is recommended to init encoder with decoder transpose' - print('We are operating in 2-d grid search for encoder and decoder respectively.') - initializing_transcoder = True - - if not initializing_transcoder: - losses = {} - for norm in torch.linspace(0.1, 1, 10).numpy().tolist(): - cfg.sae.init_decoder_norm = norm - sae = SparseAutoEncoder.from_config(cfg=cfg.sae) - mse = sae.compute_loss(x=activation_in, label=activation_out)[1][0]['l_rec'].mean().item() - losses[cfg.sae.init_decoder_norm] = mse - best_norm = min(losses, key=losses.get) - losses = {} - for norm in torch.linspace(-0.09, 0.1, 20).numpy().tolist(): - cfg.sae.init_decoder_norm = best_norm + norm - sae = SparseAutoEncoder.from_config(cfg=cfg.sae) - mse = sae.compute_loss(x=activation_in, label=activation_out)[1][0]['l_rec'].mean().item() - losses[cfg.sae.init_decoder_norm] = mse - best_norm = min(losses, key=losses.get) - print(f'The best (i.e. 
lowest MSE) initialized norm is {best_norm}') - cfg.sae.init_decoder_norm = best_norm - else: - losses = {} - for dec_norm in torch.linspace(0.1, 1, 10).numpy().tolist(): - for enc_norm in torch.linspace(0.1, 1, 10).numpy().tolist(): - cfg.sae.init_decoder_norm = dec_norm - cfg.sae.init_encoder_norm = enc_norm - sae = SparseAutoEncoder.from_config(cfg=cfg.sae) - mse = sae.compute_loss(x=activation_in, label=activation_out)[1][0]['l_rec'].mean().item() - losses[(cfg.sae.init_decoder_norm, cfg.sae.init_encoder_norm)] = mse - best_norms = min(losses, key=losses.get) - losses = {} - for dec_norm in torch.linspace(-0.09, 0.1, 20).numpy().tolist(): - for enc_norm in torch.linspace(-0.09, 0.1, 20).numpy().tolist(): - cfg.sae.init_decoder_norm = best_norms[0] + dec_norm - cfg.sae.init_encoder_norm = best_norms[1] + enc_norm - sae = SparseAutoEncoder.from_config(cfg=cfg.sae) - mse = sae.compute_loss(x=activation_in, label=activation_out)[1][0]['l_rec'].mean().item() - losses[(cfg.sae.init_decoder_norm, cfg.sae.init_encoder_norm)] = mse - best_norms = min(losses, key=losses.get) - print(f'The best (i.e. lowest MSE) initialized norms are (decoder, encoder): {best_norms}') - cfg.sae.init_decoder_norm = best_norms[0] - cfg.sae.init_encoder_norm = best_norms[1] - - return SparseAutoEncoder.from_config(cfg=cfg.sae) + if not cfg.sae.init_encoder_with_decoder_transpose or cfg.sae.hook_point_in != cfg.sae.hook_point_out: + raise NotImplementedError('Transcoders cannot be initialized automatically.') + print('SAE: Starting grid search for initial decoder norm.') + + test_sae = SparseAutoEncoder.from_config(cfg=cfg.sae) + + def grid_search_best_init_norm(search_range: List[float]) -> float: + losses: Dict[float, float] = {} + for norm in search_range: + test_sae.set_decoder_norm_to_fixed_norm(norm, force_exact=True) + test_sae.encoder.data = test_sae.decoder.data.T.clone().contiguous() + mse = test_sae.compute_loss(x=activation_in, label=activation_out)[1][0]['l_rec'].mean().item() # type: ignore + losses[norm] = mse + best_norm = min(losses, key=losses.get) # type: ignore + return best_norm + + best_norm_coarse = grid_search_best_init_norm(torch.linspace(0.1, 1, 10).numpy().tolist()) + best_norm_fine_grained = grid_search_best_init_norm(torch.linspace(best_norm_coarse - 0.09, best_norm_coarse + 0.1, 20).numpy().tolist()) + print(f'The best (i.e. lowest MSE) initialized norm is {best_norm_fine_grained}') + + test_sae.set_decoder_norm_to_fixed_norm(best_norm_fine_grained, force_exact=True) + test_sae.encoder.data = test_sae.decoder.data.T.clone().contiguous() + + return test_sae def save_pretrained( self, From b2d23803eff9794ddcb4ab8c69fd1841f47f39c9 Mon Sep 17 00:00:00 2001 From: Hzfinfdu Date: Tue, 2 Jul 2024 17:20:45 +0800 Subject: [PATCH 7/7] feature(activation): add an option to prepend bos before each sentence. 
Recommend to be True except GPT2 SAEs --- src/lm_saes/activation/activation_source.py | 2 +- src/lm_saes/activation/token_source.py | 9 ++++++--- src/lm_saes/config.py | 5 +++++ 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/lm_saes/activation/activation_source.py b/src/lm_saes/activation/activation_source.py index ce8a01c..d0b5d40 100644 --- a/src/lm_saes/activation/activation_source.py +++ b/src/lm_saes/activation/activation_source.py @@ -40,7 +40,7 @@ def __init__(self, model: HookedTransformer, cfg: ActivationStoreConfig): self.cfg = cfg def next(self) -> Dict[str, torch.Tensor] | None: - tokens = self.token_source.next(self.cfg.dataset.store_batch_size) + tokens = self.next_tokens(self.cfg.dataset.store_batch_size) if tokens is None: return None diff --git a/src/lm_saes/activation/token_source.py b/src/lm_saes/activation/token_source.py index 2cb7f5e..f825e54 100644 --- a/src/lm_saes/activation/token_source.py +++ b/src/lm_saes/activation/token_source.py @@ -17,6 +17,7 @@ def __init__( concat_tokens: list[bool], seq_len: int, sample_probs: list[float], + prepend_bos: list[bool] ): self.dataloader = dataloader self.model = model @@ -33,13 +34,14 @@ def __init__( self.resid = torch.tensor([], dtype=torch.long, device=self.device) self.sample_probs = sample_probs + self.prepend_bos = prepend_bos - def fill_with_one_batch(self, batch, pack) -> None: + def fill_with_one_batch(self, batch, pack: bool, prepend_bos: bool) -> None: if self.is_dataset_tokenized: tokens: torch.Tensor = batch["tokens"].to(self.device) else: - tokens = self.model.to_tokens(batch["text"], prepend_bos=False).to(self.device) + tokens = self.model.to_tokens(batch["text"], prepend_bos=prepend_bos).to(self.device) if pack: while tokens.size(0) > 0: cur_tokens = tokens[0] @@ -81,7 +83,7 @@ def next(self, batch_size: int) -> torch.Tensor | None: else: return None - self.fill_with_one_batch(batch, self.concat_tokens[dataset_idx_to_fetch]) + self.fill_with_one_batch(batch, self.concat_tokens[dataset_idx_to_fetch], prepend_bos=self.prepend_bos[dataset_idx_to_fetch]) ret = self.token_buffer[:batch_size] self.token_buffer = self.token_buffer[batch_size:] @@ -120,4 +122,5 @@ def from_config(model: HookedTransformer, cfg: TextDatasetConfig): concat_tokens=cfg.concat_tokens, seq_len=cfg.context_size, sample_probs=cfg.sample_probs, + prepend_bos=cfg.prepend_bos ) \ No newline at end of file diff --git a/src/lm_saes/config.py b/src/lm_saes/config.py index 438835a..808719e 100644 --- a/src/lm_saes/config.py +++ b/src/lm_saes/config.py @@ -106,6 +106,7 @@ class TextDatasetConfig(RunnerConfig): context_size: int = 128 store_batch_size: int = 64 sample_probs: List[float] = field(default_factory=lambda: [1.0]) + prepend_bos: List[bool] = field(default_factory=lambda: [False]) def __post_init__(self): super().__post_init__() @@ -115,10 +116,14 @@ def __post_init__(self): if isinstance(self.concat_tokens, bool): self.concat_tokens = [self.concat_tokens] + if isinstance(self.prepend_bos, bool): + self.prepend_bos = [self.prepend_bos] + self.sample_probs = [p / sum(self.sample_probs) for p in self.sample_probs] assert len(self.sample_probs) == len(self.dataset_path), "Number of sample_probs must match number of dataset paths" assert len(self.concat_tokens) == len(self.dataset_path), "Number of concat_tokens must match number of dataset paths" + assert len(self.prepend_bos) == len(self.dataset_path), "Number of prepend_bos must match number of dataset paths" @dataclass(kw_only=True)
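
A minimal sketch of how the new per-dataset `prepend_bos` option composes with the existing multi-dataset fields of `TextDatasetConfig`, written in the style of `examples/programmatic/train.py`. Only the text-dataset fields are shown, the dataset paths and the variable name `text_dataset_fields` are placeholders for illustration, and for a single dataset a plain bool is still accepted (it is wrapped into a one-element list in `__post_init__`):

# Per-dataset text settings as they would appear in the flattened training config.
# The lengths of dataset_path, sample_probs, concat_tokens and prepend_bos must all
# match; TextDatasetConfig.__post_init__ asserts this and normalizes sample_probs.
text_dataset_fields = dict(
    dataset_path = ["data/pretrain_corpus", "data/chat_corpus"],  # placeholder paths
    sample_probs = [8, 2],           # relative sampling weights, normalized to sum to 1
    concat_tokens = [True, False],   # pack pretraining text; keep chat records whole
    prepend_bos = [True, True],      # recommended True except for GPT-2 SAEs
)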