diff --git a/src/lm_saes/config.py b/src/lm_saes/config.py
index 8048346..9fe40c6 100644
--- a/src/lm_saes/config.py
+++ b/src/lm_saes/config.py
@@ -106,7 +106,7 @@ class TextDatasetConfig(RunnerConfig):
     context_size: int = 128
     store_batch_size: int = 64
     sample_probs: List[float] = field(default_factory=lambda: [1.0])
-    prepend_bos: List[bool] = field(default_factory=lambda: [False])
+    prepend_bos: List[bool] = field(default_factory=lambda: [True])
 
     def __post_init__(self):
         super().__post_init__()
@@ -119,6 +119,9 @@ def __post_init__(self):
         if isinstance(self.prepend_bos, bool):
             self.prepend_bos = [self.prepend_bos]
 
+        if False in self.prepend_bos:
+            print('Warning: prepend_bos is set to False for some datasets. This setting might not be suitable for most modern models.')
+
         self.sample_probs = [p / sum(self.sample_probs) for p in self.sample_probs]
 
         assert len(self.sample_probs) == len(
diff --git a/src/lm_saes/runner.py b/src/lm_saes/runner.py
index eb5d4e8..0224eee 100644
--- a/src/lm_saes/runner.py
+++ b/src/lm_saes/runner.py
@@ -47,15 +47,6 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
-    if is_master():
-        cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
-        cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))
-    sae = SparseAutoEncoder.from_config(cfg=cfg.sae)
-
-    if cfg.finetuning:
-        # Fine-tune SAE with frozen encoder weights and bias
-        sae.train_finetune_for_suppression_parameters()
-
     hf_model = AutoModelForCausalLM.from_pretrained(
         (
             cfg.lm.model_name
@@ -86,11 +77,28 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
         tokenizer=hf_tokenizer,
         dtype=cfg.lm.dtype,
     )
-    model.offload_params_after(cfg.act_store.hook_points[0], torch.tensor([[0]], device=cfg.lm.device))
+    model.offload_params_after(cfg.act_store.hook_points[-1], torch.tensor([[0]], device=cfg.lm.device))
     model.eval()
 
     activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store)
 
+    if not cfg.finetuning and (
+        cfg.sae.norm_activation == "dataset-wise" and cfg.sae.dataset_average_activation_norm is None
+        or cfg.sae.init_decoder_norm is None
+    ):
+        sae = SparseAutoEncoder.from_initialization_searching(
+            activation_store=activation_store,
+            cfg=cfg,
+        )
+    else:
+        sae = SparseAutoEncoder.from_config(cfg=cfg.sae)
+
+    if cfg.finetuning:
+        # Fine-tune SAE with frozen encoder weights and bias
+        sae.train_finetune_for_suppression_parameters()
+    if is_master():
+        cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
+        cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))
 
     if cfg.wandb.log_to_wandb and is_master():
         wandb_config: dict = {
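
For readers skimming the runner.py hunk: since Python's `and` binds tighter than `or`, the new branch reads as "(dataset-wise normalization requested and its average norm unknown) or decoder init norm unknown", and only fresh (non-finetuning) runs can take the `from_initialization_searching` path. The sketch below restates that predicate in isolation; `SAEConfigSketch` and `needs_initialization_search` are hypothetical names introduced here for illustration, not part of the repository's API.

```python
from dataclasses import dataclass
from typing import Optional


# Hypothetical, simplified stand-in for the handful of cfg.sae fields the new
# branch in language_model_sae_runner inspects; the real SAEConfig in
# src/lm_saes/config.py carries many more options.
@dataclass
class SAEConfigSketch:
    norm_activation: str = "dataset-wise"
    dataset_average_activation_norm: Optional[float] = None
    init_decoder_norm: Optional[float] = None


def needs_initialization_search(finetuning: bool, sae: SAEConfigSketch) -> bool:
    """Mirror of the condition added in runner.py: search-based initialization
    is used only for fresh training runs where either the dataset-wise
    activation norm or the decoder norm has not been fixed in the config."""
    return not finetuning and (
        (sae.norm_activation == "dataset-wise" and sae.dataset_average_activation_norm is None)
        or sae.init_decoder_norm is None
    )


# A fresh run with no precomputed norms takes the searching path ...
assert needs_initialization_search(False, SAEConfigSketch())
# ... while finetuning runs, or runs with both norms already fixed, fall back
# to plain SparseAutoEncoder.from_config(cfg=cfg.sae).
assert not needs_initialization_search(True, SAEConfigSketch())
assert not needs_initialization_search(
    False, SAEConfigSketch(dataset_average_activation_norm=12.5, init_decoder_norm=0.1)
)
```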