diff --git a/examples/configuration/analyze.toml b/examples/configuration/analyze.toml index 76d11c48..0fe81dfd 100644 --- a/examples/configuration/analyze.toml +++ b/examples/configuration/analyze.toml @@ -11,10 +11,10 @@ exp_result_dir = "results" [subsample] "top_activations" = { "proportion" = 1.0, "n_samples" = 80 } -"subsample-0.9" = { "proportion" = 0.9, "n_samples" = 20} -"subsample-0.8" = { "proportion" = 0.8, "n_samples" = 20} -"subsample-0.7" = { "proportion" = 0.7, "n_samples" = 20} -"subsample-0.5" = { "proportion" = 0.5, "n_samples" = 20} +"subsample-0.9" = { "proportion" = 0.9, "n_samples" = 20 } +"subsample-0.8" = { "proportion" = 0.8, "n_samples" = 20 } +"subsample-0.7" = { "proportion" = 0.7, "n_samples" = 20 } +"subsample-0.5" = { "proportion" = 0.5, "n_samples" = 20 } [lm] model_name = "gpt2" diff --git a/examples/configuration/prune.toml b/examples/configuration/prune.toml new file mode 100644 index 00000000..231bcaf6 --- /dev/null +++ b/examples/configuration/prune.toml @@ -0,0 +1,40 @@ +use_ddp = false +device = "cuda" +seed = 42 +dtype = "torch.float32" + +exp_name = "L3M" +exp_series = "default" +exp_result_dir = "results" + +total_training_tokens = 10_000_000 +train_batch_size = 4096 + +dead_feature_threshold = 1e-6 +dead_feature_max_act_threshold = 1.0 +decoder_norm_threshold = 0.99 + +[lm] +model_name = "gpt2" +d_model = 768 + +[dataset] +dataset_path = "openwebtext" +is_dataset_tokenized = false +is_dataset_on_disk = false +concat_tokens = false +context_size = 256 +store_batch_size = 32 + +[act_store] +device = "cuda" +seed = 42 +dtype = "torch.float32" +hook_points = [ "blocks.3.hook_mlp_out",] +use_cached_activations = false +n_tokens_in_buffer = 500000 + +[wandb] +log_to_wandb = true +wandb_project = "gpt2-sae" +wandb_entity = "fnlp-mechinterp" \ No newline at end of file diff --git a/src/lm_saes/config.py b/src/lm_saes/config.py index 9e51230e..785e6c22 100644 --- a/src/lm_saes/config.py +++ b/src/lm_saes/config.py @@ -322,6 +322,12 @@ class LanguageModelSAEPruningConfig(LanguageModelSAERunnerConfig): dead_feature_max_act_threshold: float = 1.0 decoder_norm_threshold: float = 0.99 + def __post_init__(self): + super().__post_init__() + + if not self.use_ddp or self.rank == 0: + os.makedirs(os.path.join(self.exp_result_dir, self.exp_name, "checkpoints"), exist_ok=True) + @dataclass(kw_only=True) class ActivationGenerationConfig(RunnerConfig): diff --git a/src/lm_saes/runner.py b/src/lm_saes/runner.py index b4294b54..c13eb9af 100644 --- a/src/lm_saes/runner.py +++ b/src/lm_saes/runner.py @@ -107,6 +107,8 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig): def language_model_sae_prune_runner(cfg: LanguageModelSAEPruningConfig): + cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name)) + cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name)) sae = SparseAutoEncoder.from_config(cfg=cfg.sae) hf_model = AutoModelForCausalLM.from_pretrained( (