Skip to content

Commit

Permalink
Merge pull request #47 from OpenMOSS/ft4supp
Browse files Browse the repository at this point in the history
Ft4supp
  • Loading branch information
Hzfinfdu authored Aug 9, 2024
2 parents 4369347 + 6245397 commit 9f6ccfb
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions src/lm_saes/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,6 @@


def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
if is_master():
cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))
sae = SparseAutoEncoder.from_config(cfg=cfg.sae)

if cfg.finetuning:
# Fine-tune SAE with frozen encoder weights and bias
sae.train_finetune_for_suppression_parameters()

hf_model = AutoModelForCausalLM.from_pretrained(
(
cfg.lm.model_name
Expand Down Expand Up @@ -86,11 +77,28 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
tokenizer=hf_tokenizer,
dtype=cfg.lm.dtype,
)
model.offload_params_after(cfg.act_store.hook_points[0], torch.tensor([[0]], device=cfg.lm.device))
model.offload_params_after(cfg.act_store.hook_points[-1], torch.tensor([[0]], device=cfg.lm.device))
model.eval()
activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store)

if not cfg.finetuning and (
cfg.sae.norm_activation == "dataset-wise" and cfg.sae.dataset_average_activation_norm is None
or cfg.sae.init_decoder_norm is None
):
sae = SparseAutoEncoder.from_initialization_searching(
activation_store=activation_store,
cfg=cfg,
)
else:
sae = SparseAutoEncoder.from_config(cfg=cfg.sae)

if cfg.finetuning:
# Fine-tune SAE with frozen encoder weights and bias
sae.train_finetune_for_suppression_parameters()

if is_master():
cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))

if cfg.wandb.log_to_wandb and is_master():
wandb_config: dict = {
Expand Down

0 comments on commit 9f6ccfb

Please sign in to comment.