From 6245397f90a8db9c7e8cd10f997c92729d4872d2 Mon Sep 17 00:00:00 2001 From: Hzfinfdu Date: Fri, 9 Aug 2024 18:24:31 +0800 Subject: [PATCH] fix(runner): add accidentally missing from_init_searching --- src/lm_saes/runner.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/lm_saes/runner.py b/src/lm_saes/runner.py index 17f89a3..0224eee 100644 --- a/src/lm_saes/runner.py +++ b/src/lm_saes/runner.py @@ -81,11 +81,10 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig): model.eval() activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store) - if ( + if not cfg.finetuning and ( cfg.sae.norm_activation == "dataset-wise" and cfg.sae.dataset_average_activation_norm is None or cfg.sae.init_decoder_norm is None ): - assert not cfg.finetuning sae = SparseAutoEncoder.from_initialization_searching( activation_store=activation_store, cfg=cfg, @@ -97,8 +96,9 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig): # Fine-tune SAE with frozen encoder weights and bias sae.train_finetune_for_suppression_parameters() - cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name)) - cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name)) + if is_master(): + cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name)) + cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name)) if cfg.wandb.log_to_wandb and is_master(): wandb_config: dict = {