From 5ef6d096a2db4507a1c7eda9168c52835a8a1a30 Mon Sep 17 00:00:00 2001
From: Hzfinfdu
Date: Wed, 7 Aug 2024 23:44:55 +0800
Subject: [PATCH 1/4] fix(textdataset): set default prepend_bos to True

---
 src/lm_saes/config.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/lm_saes/config.py b/src/lm_saes/config.py
index 567420d..aa84e52 100644
--- a/src/lm_saes/config.py
+++ b/src/lm_saes/config.py
@@ -106,7 +106,7 @@ class TextDatasetConfig(RunnerConfig):
     context_size: int = 128
     store_batch_size: int = 64
     sample_probs: List[float] = field(default_factory=lambda: [1.0])
-    prepend_bos: List[bool] = field(default_factory=lambda: [False])
+    prepend_bos: List[bool] = field(default_factory=lambda: [True])
 
     def __post_init__(self):
         super().__post_init__()
@@ -119,6 +119,9 @@ def __post_init__(self):
         if isinstance(self.prepend_bos, bool):
             self.prepend_bos = [self.prepend_bos]
 
+        if False in self.prepend_bos:
+            print('Warning: prepend_bos is set to False for some datasets. This setting might not be suitable for most modern models.')
+
         self.sample_probs = [p / sum(self.sample_probs) for p in self.sample_probs]
 
         assert len(self.sample_probs) == len(

From c8e32d092c6a98dc15c1de0cd8bac38221a97232 Mon Sep 17 00:00:00 2001
From: Hzfinfdu
Date: Fri, 9 Aug 2024 13:11:56 +0800
Subject: [PATCH 2/4] fix(runner): set offload after the last hook (previously
 it was the first)

---
 src/lm_saes/runner.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lm_saes/runner.py b/src/lm_saes/runner.py
index eb5d4e8..53c6469 100644
--- a/src/lm_saes/runner.py
+++ b/src/lm_saes/runner.py
@@ -86,7 +86,7 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
         tokenizer=hf_tokenizer,
         dtype=cfg.lm.dtype,
     )
-    model.offload_params_after(cfg.act_store.hook_points[0], torch.tensor([[0]], device=cfg.lm.device))
+    model.offload_params_after(cfg.act_store.hook_points[-1], torch.tensor([[0]], device=cfg.lm.device))
     model.eval()
 
     activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store)

From 9d598a2f560b7e16feb24fca76fc218c62f26ff7 Mon Sep 17 00:00:00 2001
From: Hzfinfdu
Date: Fri, 9 Aug 2024 18:22:41 +0800
Subject: [PATCH 3/4] fix(runner): add accidentally missing from_init_searching

---
 src/lm_saes/runner.py | 26 +++++++++++++++++---------
 1 file changed, 17 insertions(+), 9 deletions(-)

diff --git a/src/lm_saes/runner.py b/src/lm_saes/runner.py
index 53c6469..17f89a3 100644
--- a/src/lm_saes/runner.py
+++ b/src/lm_saes/runner.py
@@ -47,15 +47,6 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
 
-    if is_master():
-        cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
-        cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))
-    sae = SparseAutoEncoder.from_config(cfg=cfg.sae)
-
-    if cfg.finetuning:
-        # Fine-tune SAE with frozen encoder weights and bias
-        sae.train_finetune_for_suppression_parameters()
-
     hf_model = AutoModelForCausalLM.from_pretrained(
         (
             cfg.lm.model_name
@@ -90,7 +81,24 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
     model.eval()
     activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store)
 
+    if (
+        cfg.sae.norm_activation == "dataset-wise" and cfg.sae.dataset_average_activation_norm is None
+        or cfg.sae.init_decoder_norm is None
+    ):
+        assert not cfg.finetuning
+        sae = SparseAutoEncoder.from_initialization_searching(
+            activation_store=activation_store,
+            cfg=cfg,
+        )
+    else:
+        sae = SparseAutoEncoder.from_config(cfg=cfg.sae)
+
+    if cfg.finetuning:
+        # Fine-tune SAE with frozen encoder weights and bias
+        sae.train_finetune_for_suppression_parameters()
+
+    cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
+    cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))
 
     if cfg.wandb.log_to_wandb and is_master():
         wandb_config: dict = {

From 6245397f90a8db9c7e8cd10f997c92729d4872d2 Mon Sep 17 00:00:00 2001
From: Hzfinfdu
Date: Fri, 9 Aug 2024 18:24:31 +0800
Subject: [PATCH 4/4] fix(runner): add accidentally missing from_init_searching

---
 src/lm_saes/runner.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/lm_saes/runner.py b/src/lm_saes/runner.py
index 17f89a3..0224eee 100644
--- a/src/lm_saes/runner.py
+++ b/src/lm_saes/runner.py
@@ -81,11 +81,10 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
     model.eval()
     activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store)
 
-    if (
+    if not cfg.finetuning and (
         cfg.sae.norm_activation == "dataset-wise" and cfg.sae.dataset_average_activation_norm is None
         or cfg.sae.init_decoder_norm is None
     ):
-        assert not cfg.finetuning
         sae = SparseAutoEncoder.from_initialization_searching(
             activation_store=activation_store,
             cfg=cfg,
         )
@@ -97,8 +96,9 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
         # Fine-tune SAE with frozen encoder weights and bias
         sae.train_finetune_for_suppression_parameters()
 
-    cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
-    cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))
+    if is_master():
+        cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
+        cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))
 
     if cfg.wandb.log_to_wandb and is_master():
         wandb_config: dict = {
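Taken together, patches 3 and 4 leave the runner choosing between the two SAE
constructors as sketched below. This is a paraphrase for illustration, not the
repository's literal code: build_sae is a hypothetical helper name, while the
SparseAutoEncoder methods and cfg fields are those that appear in the diffs
above.

    # A minimal sketch of the post-patch-4 control flow, assuming the names
    # used in the diffs above; build_sae itself is a hypothetical helper.
    def build_sae(cfg, activation_store):
        # Python precedence: "A and B or C" parses as "(A and B) or C",
        # which is what the diff's multi-line condition relies on.
        needs_search = (
            cfg.sae.norm_activation == "dataset-wise"
            and cfg.sae.dataset_average_activation_norm is None
        ) or cfg.sae.init_decoder_norm is None

        if not cfg.finetuning and needs_search:
            # Patch 3: search for a suitable initialization using activations
            # drawn from the activation store.
            return SparseAutoEncoder.from_initialization_searching(
                activation_store=activation_store,
                cfg=cfg,
            )

        sae = SparseAutoEncoder.from_config(cfg=cfg.sae)
        if cfg.finetuning:
            # Fine-tune SAE with frozen encoder weights and bias
            sae.train_finetune_for_suppression_parameters()
        return sae

Patch 4 folds the "not fine-tuning" requirement into the condition itself
(replacing patch 3's assert), so a fine-tuning run can never enter the
initialization search, and restores the is_master() guard so that only the
master process writes the hyperparameter and LM config files.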