From e97eb60e99ff411a38accd3affb8e986c908e023 Mon Sep 17 00:00:00 2001
From: Hzfinfdu
Date: Wed, 31 Jul 2024 12:18:42 +0800
Subject: [PATCH 1/6] feat(sae): add a utils func to merge pre-enc bias into
 enc bias

---
 src/lm_saes/utils/convert_pre_enc_bias.py | 12 ++++++++++++
 tests/unit/test_convert_pre_enc_bias.py   | 15 +++++++++++++++
 2 files changed, 27 insertions(+)
 create mode 100644 src/lm_saes/utils/convert_pre_enc_bias.py
 create mode 100644 tests/unit/test_convert_pre_enc_bias.py

diff --git a/src/lm_saes/utils/convert_pre_enc_bias.py b/src/lm_saes/utils/convert_pre_enc_bias.py
new file mode 100644
index 0000000..03e7b6f
--- /dev/null
+++ b/src/lm_saes/utils/convert_pre_enc_bias.py
@@ -0,0 +1,12 @@
+from lm_saes.sae import SparseAutoEncoder
+import torch
+
+
+@torch.no_grad()
+def merge_pre_enc_bias_to_enc_bias(sae: SparseAutoEncoder):
+    assert sae.cfg.apply_decoder_bias_to_pre_encoder
+
+    sae.cfg.apply_decoder_bias_to_pre_encoder = False
+    sae.encoder.bias.data = sae.encoder.bias.data - sae.encoder.weight.data @ sae.decoder.bias.data
+
+    return sae
\ No newline at end of file
diff --git a/tests/unit/test_convert_pre_enc_bias.py b/tests/unit/test_convert_pre_enc_bias.py
new file mode 100644
index 0000000..f5cbb92
--- /dev/null
+++ b/tests/unit/test_convert_pre_enc_bias.py
@@ -0,0 +1,15 @@
+from lm_saes.sae import SparseAutoEncoder
+from lm_saes.config import SAEConfig
+from lm_saes.utils.convert_pre_enc_bias import merge_pre_enc_bias_to_enc_bias
+import torch
+
+cfg = SAEConfig(
+    d_model=512,
+    expansion_factor=4,
+    apply_decoder_bias_to_pre_encoder=True,
+)
+
+sae = SparseAutoEncoder(cfg)
+sample = torch.randn(4, cfg.d_model)
+
+assert (sae(sample) == merge_pre_enc_bias_to_enc_bias(sae)(sample)).all()
\ No newline at end of file

From fd6c6ed11cfe841b686305788f63cc41723f793c Mon Sep 17 00:00:00 2001
From: Hzfinfdu
Date: Wed, 31 Jul 2024 14:30:11 +0800
Subject: [PATCH 2/6] fix(sae): do not init device mesh in single device mode

---
 src/lm_saes/sae.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py
index d238422..278211e 100644
--- a/src/lm_saes/sae.py
+++ b/src/lm_saes/sae.py
@@ -60,9 +60,10 @@ def __init__(self, cfg: SAEConfig):
         )
         torch.nn.init.kaiming_uniform_(self.encoder.weight)
         torch.nn.init.zeros_(self.encoder.bias)
-        self.device_mesh = init_device_mesh(
-            "cuda", (cfg.ddp_size, cfg.tp_size), mesh_dim_names=("ddp", "tp")
-        )
+        if cfg.tp_size > 1 or cfg.ddp_size > 1:
+            self.device_mesh = init_device_mesh(
+                "cuda", (cfg.ddp_size, cfg.tp_size), mesh_dim_names=("ddp", "tp")
+            )


         if cfg.use_glu_encoder:
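Note: the bias merge in PATCH 1/6 is output-preserving because, with apply_decoder_bias_to_pre_encoder enabled, the encoder pre-activation is W_enc (x - b_dec) + b_enc = W_enc x + (b_enc - W_enc b_dec), so subtracting W_enc @ b_dec from the encoder bias and dropping the pre-encoder subtraction computes the same thing. A minimal standalone sketch of that identity (plain torch.nn.Linear with made-up shapes, not the repo's SparseAutoEncoder):

import torch

d_model, d_sae = 8, 32
encoder = torch.nn.Linear(d_model, d_sae)
b_dec = torch.randn(d_model)  # stand-in for sae.decoder.bias
x = torch.randn(4, d_model)

# Original behaviour: the decoder bias is subtracted before encoding.
out_before = encoder(x - b_dec)

# Fold the pre-encoder bias into the encoder bias, as merge_pre_enc_bias_to_enc_bias does.
with torch.no_grad():
    encoder.bias.data -= encoder.weight.data @ b_dec

# Merged behaviour: encode x directly.
out_after = encoder(x)

# allclose rather than exact equality, since the reassociated float arithmetic
# can differ in the last bits.
assert torch.allclose(out_before, out_after, atol=1e-6)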
From 9934a6ad081959dd0a93c3c073e46e10c21339cb Mon Sep 17 00:00:00 2001
From: Hzfinfdu
Date: Sun, 4 Aug 2024 10:21:09 +0800
Subject: [PATCH 3/6] feat(ft4supp): support ft4supp adjusted for AprilTrick
 update SAEs

---
 src/lm_saes/config.py | 12 ++----------
 src/lm_saes/runner.py | 16 ++++------------
 src/lm_saes/sae.py    |  8 +++-----
 3 files changed, 9 insertions(+), 27 deletions(-)

diff --git a/src/lm_saes/config.py b/src/lm_saes/config.py
index 75b50c3..567420d 100644
--- a/src/lm_saes/config.py
+++ b/src/lm_saes/config.py
@@ -394,6 +394,8 @@ def __post_init__(self):
             assert 0 <= self.lr_cool_down_steps <= 1.0
             self.lr_cool_down_steps = int(self.lr_cool_down_steps * total_training_steps)
             print_once(f"Learning rate cool down steps: {self.lr_cool_down_steps}")
+        if self.finetuning:
+            assert self.l1_coefficient == 0.0, "L1 coefficient must be 0.0 for finetuning."

 @dataclass(kw_only=True)
 class LanguageModelSAEPruningConfig(LanguageModelSAERunnerConfig):
@@ -471,16 +473,6 @@ class LanguageModelSAEAnalysisConfig(RunnerConfig):
         }
     )

-    n_sae_chunks: int = (
-        1  # Number of chunks to split the SAE into for analysis. For large models and SAEs, this can be useful to avoid memory issues.
-    )
-
-    def __post_init__(self):
-        super().__post_init__()
-        assert (
-            self.sae.d_sae % self.n_sae_chunks == 0
-        ), f"d_sae ({self.sae.d_sae}) must be divisible by n_sae_chunks ({self.n_sae_chunks})"
-

 @dataclass(kw_only=True)
 class FeaturesDecoderConfig(RunnerConfig):
diff --git a/src/lm_saes/runner.py b/src/lm_saes/runner.py
index 0e4b598..00f4c55 100644
--- a/src/lm_saes/runner.py
+++ b/src/lm_saes/runner.py
@@ -50,11 +50,6 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
     if is_master():
         cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
         cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))
-    sae = SparseAutoEncoder.from_config(cfg=cfg.sae)
-
-    if cfg.finetuning:
-        # Fine-tune SAE with frozen encoder weights and bias
-        sae.train_finetune_for_suppression_parameters()

     hf_model = AutoModelForCausalLM.from_pretrained(
         (
@@ -90,13 +85,10 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
     model.eval()
     activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store)
-
-
-    if (
+    if not cfg.finetuning and (
         cfg.sae.norm_activation == "dataset-wise"
         and cfg.sae.dataset_average_activation_norm is None
         or cfg.sae.init_decoder_norm is None
     ):
-        assert not cfg.finetuning
         sae = SparseAutoEncoder.from_initialization_searching(
             activation_store=activation_store,
             cfg=cfg,
@@ -108,8 +100,9 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
             # Fine-tune SAE with frozen encoder weights and bias
             sae.train_finetune_for_suppression_parameters()

-    cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
-    cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))
+    if is_master():
+        cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
+        cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))

     if cfg.wandb.log_to_wandb and is_master():
         wandb_config: dict = {
@@ -399,7 +392,6 @@ def sample_feature_activations_runner(cfg: LanguageModelSAEAnalysisConfig):
     del activation_store
     torch.cuda.empty_cache()

-
 @torch.no_grad()
 def features_to_logits_runner(cfg: FeaturesDecoderConfig):
     sae = SparseAutoEncoder.from_config(cfg=cfg.sae)
diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py
index 278211e..567379a 100644
--- a/src/lm_saes/sae.py
+++ b/src/lm_saes/sae.py
@@ -137,12 +137,10 @@ def train_base_parameters(self):
             p.requires_grad_(True)

     def train_finetune_for_suppression_parameters(self):
-        """Set the parameters to be trained for feature suppression."""
+        """Set the parameters to be trained against feature suppression."""
+
+        finetune_for_suppression_parameters = [self.decoder.weight]

-        finetune_for_suppression_parameters = [
-            self.feature_act_scale,
-            self.decoder.weight,
-        ]
         if self.cfg.use_decoder_bias:
             finetune_for_suppression_parameters.append(self.decoder.bias)
         for p in self.parameters():
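The decoder-only fine-tuning set up in PATCH 3/6 keeps just the decoder weight trainable, plus the decoder bias when use_decoder_bias is set. A self-contained sketch of that freezing pattern on a toy module (illustrative names and shapes; the exact flag handling in the real method lies outside the hunk shown):

import torch

class TinySAE(torch.nn.Module):
    def __init__(self, d_model: int = 8, d_sae: int = 32, use_decoder_bias: bool = True):
        super().__init__()
        self.encoder = torch.nn.Linear(d_model, d_sae)
        self.decoder = torch.nn.Linear(d_sae, d_model, bias=use_decoder_bias)

    def train_finetune_for_suppression_parameters(self):
        # Only the decoder weight (and optionally its bias) stays trainable;
        # everything else is frozen.
        trainable = [self.decoder.weight]
        if self.decoder.bias is not None:
            trainable.append(self.decoder.bias)
        for p in self.parameters():
            p.requires_grad_(False)
        for p in trainable:
            p.requires_grad_(True)

sae = TinySAE()
sae.train_finetune_for_suppression_parameters()
optimizer = torch.optim.Adam([p for p in sae.parameters() if p.requires_grad], lr=1e-4)
assert not sae.encoder.weight.requires_grad and sae.decoder.weight.requires_grad

This is presumably also why the config assert above forces l1_coefficient to 0.0 during finetuning: with the encoder frozen, an L1 penalty on the (fixed) feature activations cannot influence the trainable decoder parameters.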
From f0aa9c9c07b3093a0491505defeac31f2cbf5a12 Mon Sep 17 00:00:00 2001
From: Hzfinfdu
Date: Sun, 4 Aug 2024 10:25:27 +0800
Subject: [PATCH 4/6] feat(ft4supp): support ft4supp adjusted for AprilTrick
 update SAEs

---
 src/lm_saes/runner.py | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/src/lm_saes/runner.py b/src/lm_saes/runner.py
index 395d3fc..aaab27c 100644
--- a/src/lm_saes/runner.py
+++ b/src/lm_saes/runner.py
@@ -32,6 +32,19 @@
 from torch.nn.parallel import DistributedDataParallel as DDP
 from lm_saes.utils.misc import is_master

+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    loss_parallel,
+)
+from torch.distributed._tensor import (
+    DTensor,
+    Shard,
+    Replicate,
+    distribute_module,
+    distribute_tensor,
+)
+

 def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
     if is_master():
@@ -304,6 +317,20 @@ def activation_generation_runner(cfg: ActivationGenerationConfig):
 def sample_feature_activations_runner(cfg: LanguageModelSAEAnalysisConfig):
     sae = SparseAutoEncoder.from_config(cfg=cfg.sae)

+    if cfg.sae.tp_size > 1:
+        plan = {
+            "encoder": ColwiseParallel(output_layouts=Replicate()),
+        }
+        if cfg.sae.use_glu_encoder:
+            plan["encoder_glu"] = ColwiseParallel(output_layouts=Replicate())
+        sae = parallelize_module(sae, device_mesh=sae.device_mesh["tp"], parallelize_plan=plan)  # type: ignore
+        sae.parallelize_plan = plan
+
+    sae.decoder.weight = None
+    torch.cuda.empty_cache()
+
+
+
     hf_model = AutoModelForCausalLM.from_pretrained(
         (
             cfg.lm.model_name

From 3efe943bb9cf204e1eadc6b9255fc10ec24d686f Mon Sep 17 00:00:00 2001
From: Hzfinfdu
Date: Sun, 4 Aug 2024 10:31:46 +0800
Subject: [PATCH 5/6] fix(misc): remove unnecessary changes

---
 src/lm_saes/config.py | 10 ++++++++++
 src/lm_saes/runner.py |  1 -
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/src/lm_saes/config.py b/src/lm_saes/config.py
index 567420d..9f36a3b 100644
--- a/src/lm_saes/config.py
+++ b/src/lm_saes/config.py
@@ -473,6 +473,16 @@ class LanguageModelSAEAnalysisConfig(RunnerConfig):
         }
     )

+    n_sae_chunks: int = (
+        1  # Number of chunks to split the SAE into for analysis. For large models and SAEs, this can be useful to avoid memory issues.
+    )
+
+    def __post_init__(self):
+        super().__post_init__()
+        assert (
+            self.sae.d_sae % self.n_sae_chunks == 0
+        ), f"d_sae ({self.sae.d_sae}) must be divisible by n_sae_chunks ({self.n_sae_chunks})"
+

 @dataclass(kw_only=True)
 class FeaturesDecoderConfig(RunnerConfig):
diff --git a/src/lm_saes/runner.py b/src/lm_saes/runner.py
index e03bfaa..eb5d4e8 100644
--- a/src/lm_saes/runner.py
+++ b/src/lm_saes/runner.py
@@ -309,7 +309,6 @@ def sample_feature_activations_runner(cfg: LanguageModelSAEAnalysisConfig):
         sae = parallelize_module(sae, device_mesh=sae.device_mesh["tp"], parallelize_plan=plan)  # type: ignore
         sae.parallelize_plan = plan

-
     sae.decoder.weight = None  # type: ignore[assignment]
     torch.cuda.empty_cache()

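The analysis-time sharding added in PATCH 4/6 uses the stock torch.distributed tensor-parallel API: each encoder Linear is split column-wise over the "tp" mesh dimension and its output is replicated, so downstream feature-activation code still sees full-width tensors, while the decoder weight is set to None before torch.cuda.empty_cache(), presumably to free memory it does not need during analysis. A hedged sketch of the same plan on a generic module with an `encoder` submodule (standalone device mesh, assumes a torchrun launch with one CUDA device per rank; not the runner's exact code):

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
from torch.distributed._tensor import Replicate


def tp_shard_encoder(sae: torch.nn.Module, tp_size: int, use_glu_encoder: bool = False) -> torch.nn.Module:
    # Build a 1-D "tp" mesh and shard the encoder's output dimension (d_sae) across it,
    # replicating the result so callers get ordinary full-size activations back.
    mesh = init_device_mesh("cuda", (tp_size,), mesh_dim_names=("tp",))
    plan = {"encoder": ColwiseParallel(output_layouts=Replicate())}
    if use_glu_encoder:
        plan["encoder_glu"] = ColwiseParallel(output_layouts=Replicate())
    return parallelize_module(sae, device_mesh=mesh, parallelize_plan=plan)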
From f80cb0ea2646089129be1a04819608a59242957a Mon Sep 17 00:00:00 2001
From: Hzfinfdu
Date: Sun, 4 Aug 2024 12:19:11 +0800
Subject: [PATCH 6/6] fix(ft4supp): supp final ver.

---
 src/lm_saes/config.py            | 2 +-
 src/lm_saes/utils/huggingface.py | 2 ++
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/lm_saes/config.py b/src/lm_saes/config.py
index 9f36a3b..8048346 100644
--- a/src/lm_saes/config.py
+++ b/src/lm_saes/config.py
@@ -395,7 +395,7 @@ def __post_init__(self):
             self.lr_cool_down_steps = int(self.lr_cool_down_steps * total_training_steps)
             print_once(f"Learning rate cool down steps: {self.lr_cool_down_steps}")
         if self.finetuning:
-            assert self.l1_coefficient == 0.0, "L1 coefficient must be 0.0 for finetuning."
+            assert self.sae.l1_coefficient == 0.0, "L1 coefficient must be 0.0 for finetuning."

 @dataclass(kw_only=True)
 class LanguageModelSAEPruningConfig(LanguageModelSAERunnerConfig):
diff --git a/src/lm_saes/utils/huggingface.py b/src/lm_saes/utils/huggingface.py
index d51bae0..7807d4f 100644
--- a/src/lm_saes/utils/huggingface.py
+++ b/src/lm_saes/utils/huggingface.py
@@ -4,6 +4,7 @@
 import os
 import shutil
 from huggingface_hub import create_repo, upload_folder, snapshot_download
+from lm_saes.utils.misc import print_once


 def upload_pretrained_sae_to_hf(sae_path: str, repo_id: str, private: bool = False):
@@ -54,6 +55,7 @@ def parse_pretrained_name_or_path(pretrained_name_or_path: str):
     if os.path.exists(pretrained_name_or_path):
         return pretrained_name_or_path
     else:
+        print_once(f'Local path `{pretrained_name_or_path}` not found. Downloading from huggingface model hub.')
         repo_id = "/".join(pretrained_name_or_path.split("/")[:2])
         hook_point = "/".join(pretrained_name_or_path.split("/")[2:])
        return download_pretrained_sae_from_hf(repo_id, hook_point)
\ No newline at end of file
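For reference, the fallback path exercised by the new print_once call splits pretrained_name_or_path into a two-component Hugging Face repo id and a hook-point suffix. A tiny illustration with a hypothetical name (the repo id and hook point below are made up):

name = "some-org/some-sae-suite/blocks.3.hook_resid_post"  # hypothetical pretrained_name_or_path
repo_id = "/".join(name.split("/")[:2])     # -> "some-org/some-sae-suite"
hook_point = "/".join(name.split("/")[2:])  # -> "blocks.3.hook_resid_post"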