fix(runner): add accidentally missing from_init_searching
Hzfinfdu committed Aug 9, 2024
1 parent c8e32d0 commit 9d598a2
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions src/lm_saes/runner.py
@@ -47,15 +47,6 @@
 
 
 def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
-    if is_master():
-        cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
-        cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))
-    sae = SparseAutoEncoder.from_config(cfg=cfg.sae)
-
-    if cfg.finetuning:
-        # Fine-tune SAE with frozen encoder weights and bias
-        sae.train_finetune_for_suppression_parameters()
-
     hf_model = AutoModelForCausalLM.from_pretrained(
         (
             cfg.lm.model_name
@@ -90,7 +81,24 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
     model.eval()
     activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store)
 
+    if (
+        cfg.sae.norm_activation == "dataset-wise" and cfg.sae.dataset_average_activation_norm is None
+        or cfg.sae.init_decoder_norm is None
+    ):
+        assert not cfg.finetuning
+        sae = SparseAutoEncoder.from_initialization_searching(
+            activation_store=activation_store,
+            cfg=cfg,
+        )
+    else:
+        sae = SparseAutoEncoder.from_config(cfg=cfg.sae)
+
+    if cfg.finetuning:
+        # Fine-tune SAE with frozen encoder weights and bias
+        sae.train_finetune_for_suppression_parameters()
+
+    cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
+    cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))
 
     if cfg.wandb.log_to_wandb and is_master():
         wandb_config: dict = {
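Note on the new condition (not part of the commit itself): Python's `and` binds tighter than `or`, so the check added above reads as "(dataset-wise normalization is requested but no dataset average activation norm is stored) or (no initial decoder norm is given)". Below is a minimal sketch of the equivalent, explicitly parenthesized predicate, using the config fields shown in the diff; the helper name is hypothetical and only for illustration.

    # Illustrative sketch only -- not code from the repository.
    # `and` binds tighter than `or`, so the condition in the diff,
    #     A and B or C
    # is equivalent to (A and B) or C.
    def needs_initialization_searching(sae_cfg) -> bool:
        missing_dataset_norm = (
            sae_cfg.norm_activation == "dataset-wise"
            and sae_cfg.dataset_average_activation_norm is None
        )
        missing_decoder_norm = sae_cfg.init_decoder_norm is None
        return missing_dataset_norm or missing_decoder_norm

When this predicate holds, the runner builds the SAE via SparseAutoEncoder.from_initialization_searching, which takes the activation store as an argument; otherwise it falls back to SparseAutoEncoder.from_config. That dependency is presumably why the SAE construction (and the hyperparameter saving) now happens after ActivationStore.from_config rather than at the top of the function.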
