diff --git a/src/lm_saes/runner.py b/src/lm_saes/runner.py
index 17f89a3e..0224eee1 100644
--- a/src/lm_saes/runner.py
+++ b/src/lm_saes/runner.py
@@ -81,11 +81,10 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
     model.eval()
     activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store)
-    if (
+    if not cfg.finetuning and (
         cfg.sae.norm_activation == "dataset-wise"
         and cfg.sae.dataset_average_activation_norm is None
         or cfg.sae.init_decoder_norm is None
     ):
-        assert not cfg.finetuning
         sae = SparseAutoEncoder.from_initialization_searching(
             activation_store=activation_store,
             cfg=cfg,
@@ -97,8 +96,9 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
         # Fine-tune SAE with frozen encoder weights and bias
         sae.train_finetune_for_suppression_parameters()
 
-    cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
-    cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))
+    if is_master():
+        cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
+        cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))
 
     if cfg.wandb.log_to_wandb and is_master():
         wandb_config: dict = {
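
Note: the second hunk wraps the config-saving calls in an is_master() guard, so that in multi-process (distributed) training only rank 0 writes the hyperparameter and LM config files, instead of every rank racing to write the same paths. The diff does not show the definition of is_master(); the sketch below is a minimal assumed implementation of that rank-0 guard pattern using torch.distributed, not the actual helper from this repository.

import torch.distributed as dist

def is_master() -> bool:
    # Single-process runs have no process group initialized and are
    # trivially the master; otherwise only global rank 0 is.
    if not dist.is_available() or not dist.is_initialized():
        return True
    return dist.get_rank() == 0

The first hunk makes a related behavioral change: instead of asserting that finetuning is off after entering the initialization-searching branch (crashing a finetuning run that reaches it), the not cfg.finetuning check is folded into the branch condition itself, so finetuning runs simply skip initialization searching.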