diff --git a/src/lm_saes/config.py b/src/lm_saes/config.py
index 75b50c3..8048346 100644
--- a/src/lm_saes/config.py
+++ b/src/lm_saes/config.py
@@ -394,6 +394,8 @@ def __post_init__(self):
         assert 0 <= self.lr_cool_down_steps <= 1.0
         self.lr_cool_down_steps = int(self.lr_cool_down_steps * total_training_steps)
         print_once(f"Learning rate cool down steps: {self.lr_cool_down_steps}")
+        if self.finetuning:
+            assert self.sae.l1_coefficient == 0.0, "L1 coefficient must be 0.0 for finetuning."
 
 @dataclass(kw_only=True)
 class LanguageModelSAEPruningConfig(LanguageModelSAERunnerConfig):
diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py
index 1630057..1f011ab 100644
--- a/src/lm_saes/sae.py
+++ b/src/lm_saes/sae.py
@@ -140,12 +140,10 @@ def train_base_parameters(self):
             p.requires_grad_(True)
 
     def train_finetune_for_suppression_parameters(self):
-        """Set the parameters to be trained for feature suppression."""
+        """Set the parameters to be trained against feature suppression."""
+
+        finetune_for_suppression_parameters = [self.decoder.weight]
-        finetune_for_suppression_parameters = [
-            self.feature_act_scale,
-            self.decoder.weight,
-        ]
         if self.cfg.use_decoder_bias:
             finetune_for_suppression_parameters.append(self.decoder.bias)
 
         for p in self.parameters():
diff --git a/src/lm_saes/utils/huggingface.py b/src/lm_saes/utils/huggingface.py
index d51bae0..7807d4f 100644
--- a/src/lm_saes/utils/huggingface.py
+++ b/src/lm_saes/utils/huggingface.py
@@ -4,6 +4,7 @@
 import os
 import shutil
 from huggingface_hub import create_repo, upload_folder, snapshot_download
+from lm_saes.utils.misc import print_once
 
 
 def upload_pretrained_sae_to_hf(sae_path: str, repo_id: str, private: bool = False):
@@ -54,6 +55,7 @@ def parse_pretrained_name_or_path(pretrained_name_or_path: str):
     if os.path.exists(pretrained_name_or_path):
         return pretrained_name_or_path
     else:
+        print_once(f'Local path `{pretrained_name_or_path}` not found. Downloading from huggingface model hub.')
        repo_id = "/".join(pretrained_name_or_path.split("/")[:2])
        hook_point = "/".join(pretrained_name_or_path.split("/")[2:])
        return download_pretrained_sae_from_hf(repo_id, hook_point)
\ No newline at end of file
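
For reference, a minimal sketch of how the new `config.py` guard is expected to behave. The two dataclasses below are stripped-down stand-ins (assumptions, not the repository's `LanguageModelSAERunnerConfig`); only the `finetuning` / `l1_coefficient` interaction from the hunk above is modeled:

```python
from dataclasses import dataclass, field

@dataclass(kw_only=True)
class SAEConfigSketch:  # hypothetical stand-in, not the repo's SAE config
    l1_coefficient: float = 8e-5  # arbitrary nonzero default for illustration

@dataclass(kw_only=True)
class RunnerConfigSketch:  # hypothetical stand-in for LanguageModelSAERunnerConfig
    sae: SAEConfigSketch = field(default_factory=SAEConfigSketch)
    finetuning: bool = False

    def __post_init__(self):
        # Mirrors the added guard: finetuning must run with the L1 penalty disabled.
        if self.finetuning:
            assert self.sae.l1_coefficient == 0.0, "L1 coefficient must be 0.0 for finetuning."

RunnerConfigSketch(sae=SAEConfigSketch(l1_coefficient=0.0), finetuning=True)  # passes
# RunnerConfigSketch(finetuning=True)  # would raise: the default l1_coefficient is nonzero
```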
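
Similarly, a self-contained sketch of the freeze/unfreeze pattern that `train_finetune_for_suppression_parameters` now implements. `TinySAE` is an assumption for illustration, not the repository's `SparseAutoEncoder`; note that `feature_act_scale` is deliberately absent from the trainable set, matching the hunk:

```python
import torch.nn as nn

class TinySAE(nn.Module):  # hypothetical module, arbitrarily sized
    def __init__(self, d_model: int = 16, d_sae: int = 64, use_decoder_bias: bool = True):
        super().__init__()
        self.use_decoder_bias = use_decoder_bias
        self.encoder = nn.Linear(d_model, d_sae)
        self.decoder = nn.Linear(d_sae, d_model, bias=use_decoder_bias)

    def train_finetune_for_suppression_parameters(self):
        # Per the diff: only the decoder weight (plus its bias, when enabled)
        # stays trainable; every other parameter is frozen.
        finetune_params = [self.decoder.weight]
        if self.use_decoder_bias:
            finetune_params.append(self.decoder.bias)
        for p in self.parameters():
            p.requires_grad_(False)
        for p in finetune_params:
            p.requires_grad_(True)

sae = TinySAE()
sae.train_finetune_for_suppression_parameters()
print([n for n, p in sae.named_parameters() if p.requires_grad])
# ['decoder.weight', 'decoder.bias']
```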
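
Finally, the fallback branch in `parse_pretrained_name_or_path` (where the new `print_once` call lands) treats the first two `/`-separated segments of a non-local name as the Hugging Face repo id and the remainder as the hook point. The name below is hypothetical:

```python
name = "some-org/some-sae-release/blocks.3.hook_resid_post"  # hypothetical, not a real repo
repo_id = "/".join(name.split("/")[:2])     # "some-org/some-sae-release"
hook_point = "/".join(name.split("/")[2:])  # "blocks.3.hook_resid_post"
print(repo_id, hook_point)
```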