diff --git a/examples/configuration/analyze.toml b/examples/configuration/analyze.toml
index eec1c74..09319c9 100644
--- a/examples/configuration/analyze.toml
+++ b/examples/configuration/analyze.toml
@@ -7,7 +7,7 @@
 dtype = "torch.float32"
 exp_name = "L3M"
 exp_series = "default"
-exp_result_dir = "results"
+exp_result_path = "results/L3M"
 
 [subsample]
 "top_activations" = { "proportion" = 1.0, "n_samples" = 80 }
diff --git a/examples/configuration/prune.toml b/examples/configuration/prune.toml
index f51991b..0c55276 100644
--- a/examples/configuration/prune.toml
+++ b/examples/configuration/prune.toml
@@ -5,7 +5,7 @@
 dtype = "torch.float32"
 exp_name = "L3M"
 exp_series = "default"
-exp_result_dir = "results"
+exp_result_path = "results/L3M"
 
 total_training_tokens = 10_000_000
 train_batch_size = 4096
diff --git a/examples/configuration/train.toml b/examples/configuration/train.toml
index dec963a..b22c9c8 100644
--- a/examples/configuration/train.toml
+++ b/examples/configuration/train.toml
@@ -1,6 +1,6 @@
 use_ddp = false
 exp_name = "L3M"
-exp_result_dir = "results"
+exp_result_path = "results/L3M"
 device = "cuda"
 seed = 42
 dtype = "torch.float32"
diff --git a/examples/programmatic/analyze.py b/examples/programmatic/analyze.py
index d60f259..cac1eaa 100644
--- a/examples/programmatic/analyze.py
+++ b/examples/programmatic/analyze.py
@@ -41,7 +41,7 @@
     exp_name = "L3M",
     exp_series = "default",
-    exp_result_dir = "results",
+    exp_result_path = "results/L3M",
 ))
 
 sample_feature_activations_runner(cfg)
\ No newline at end of file
diff --git a/examples/programmatic/train.py b/examples/programmatic/train.py
index 0f433a6..80e60ce 100644
--- a/examples/programmatic/train.py
+++ b/examples/programmatic/train.py
@@ -69,7 +69,7 @@
     exp_name = f"test",  # The experiment name. Would be used for creating exp folder (which may contain checkpoints and analysis results) and setting wandb run name.
exp_series = "test", - exp_result_dir = "results" + exp_result_path = "results/test" )) sparse_autoencoder = language_model_sae_runner(cfg) \ No newline at end of file diff --git a/server/app.py b/server/app.py index a32fc1e..cb29624 100644 --- a/server/app.py +++ b/server/app.py @@ -43,7 +43,10 @@ def get_model(dictionary_name: str) -> HookedTransformer: - cfg = LanguageModelConfig.from_pretrained_sae(f"{result_dir}/{dictionary_name}") + path = client.get_dictionary_path(dictionary_name, dictionary_series=dictionary_series) + if path is "": + path = f"{result_dir}/{dictionary_name}" + cfg = LanguageModelConfig.from_pretrained_sae(path) if (cfg.model_name, cfg.model_from_pretrained_path) not in lm_cache: hf_model = AutoModelForCausalLM.from_pretrained( ( diff --git a/src/lm_saes/config.py b/src/lm_saes/config.py index 75b50c3..606d695 100644 --- a/src/lm_saes/config.py +++ b/src/lm_saes/config.py @@ -53,13 +53,12 @@ def __post_init__(self): class RunnerConfig(BaseConfig): exp_name: str = "test" exp_series: Optional[str] = None - exp_result_dir: str = "results" + exp_result_path: str = "results" def __post_init__(self): super().__post_init__() if is_master(): - os.makedirs(self.exp_result_dir, exist_ok=True) - os.makedirs(os.path.join(self.exp_result_dir, self.exp_name), exist_ok=True) + os.makedirs(self.exp_result_path, exist_ok=True) @dataclass(kw_only=True) @@ -254,12 +253,12 @@ def from_pretrained( @deprecated("Use from_pretrained and to_dict instead.") @staticmethod def get_hyperparameters( - exp_name: str, exp_result_dir: str, ckpt_name: str, strict_loading: bool = True + exp_result_path: str, ckpt_name: str, strict_loading: bool = True ) -> dict[str, Any]: - with open(os.path.join(exp_result_dir, exp_name, "hyperparams.json"), "r") as f: + with open(os.path.join(exp_result_path, "hyperparams.json"), "r") as f: hyperparams = json.load(f) hyperparams["sae_pretrained_name_or_path"] = os.path.join( - exp_result_dir, exp_name, "checkpoints", ckpt_name + exp_result_path, "checkpoints", ckpt_name ) hyperparams["strict_loading"] = strict_loading # Remove non-hyperparameters from the dict @@ -348,12 +347,12 @@ def __post_init__(self): if is_master(): if os.path.exists( - os.path.join(self.exp_result_dir, self.exp_name, "checkpoints") + os.path.join(self.exp_result_path, "checkpoints") ): raise ValueError( - f"Checkpoints for experiment {self.exp_name} already exist. Consider changing the experiment name." + f"Checkpoints for experiment {self.exp_result_path} already exist. Consider changing the experiment name." 
                 )
-            os.makedirs(os.path.join(self.exp_result_dir, self.exp_name, "checkpoints"))
+            os.makedirs(os.path.join(self.exp_result_path, "checkpoints"))
 
         self.effective_batch_size = self.train_batch_size * self.sae.ddp_size
         print_once(f"Effective batch size: {self.effective_batch_size}")
@@ -413,7 +412,7 @@ def __post_init__(self):
 
         if is_master():
             os.makedirs(
-                os.path.join(self.exp_result_dir, self.exp_name, "checkpoints"),
+                os.path.join(self.exp_result_path, "checkpoints"),
                 exist_ok=True,
             )
 
diff --git a/src/lm_saes/database.py b/src/lm_saes/database.py
index 75f55b3..8a92484 100644
--- a/src/lm_saes/database.py
+++ b/src/lm_saes/database.py
@@ -58,8 +58,8 @@ def _remove_gridfs_objs(self, data):
         if isinstance(data, ObjectId) and self.fs.exists(data):
             self.fs.delete(data)
 
-    def create_dictionary(self, dictionary_name: str, n_features: int, dictionary_series: str | None = None):
-        dict_id = self.dictionary_collection.insert_one({'name': dictionary_name, 'n_features': n_features, 'series': dictionary_series}).inserted_id
+    def create_dictionary(self, dictionary_name: str, dictionary_path: str, n_features: int, dictionary_series: str | None = None):
+        dict_id = self.dictionary_collection.insert_one({'name': dictionary_name, 'n_features': n_features, 'series': dictionary_series, 'path': dictionary_path}).inserted_id
         self.feature_collection.insert_many([
             {
                 'dictionary_id': dict_id,
@@ -85,6 +85,7 @@ def update_feature(self, dictionary_name: str, feature_index: int, feature_data:
         self.feature_collection.update_one({'_id': feature['_id']}, {'$set': self._to_gridfs(feature_data)})
 
     def list_dictionaries(self, dictionary_series: str | None = None):
+        # return [{'name': d['name'], 'path': d['path']} for d in self.dictionary_collection.find({'series': dictionary_series} if dictionary_series is not None else {})]
         return [d['name'] for d in self.dictionary_collection.find({'series': dictionary_series} if dictionary_series is not None else {})]
 
     def get_feature(self, dictionary_name: str, feature_index: int, dictionary_series: str | None = None):
@@ -200,3 +201,8 @@ def get_attn_head(self, dictionary_name: str, head_index: int, dictionary_series
         ]
         return self.attn_head_collection.aggregate(pipeline).next()
 
+    def get_dictionary_path(self, dictionary_name: str, dictionary_series: str | None = None):
+        dictionary = self.dictionary_collection.find_one({'name': dictionary_name, 'series': dictionary_series})
+        if dictionary is None:
+            return None
+        return dictionary['path']
\ No newline at end of file
diff --git a/src/lm_saes/runner.py b/src/lm_saes/runner.py
index bbe8776..088099b 100644
--- a/src/lm_saes/runner.py
+++ b/src/lm_saes/runner.py
@@ -49,8 +49,8 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
     if is_master():
-        cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
-        cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))
+        cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_path))
+        cfg.lm.save_lm_config(os.path.join(cfg.exp_result_path))
     sae = SparseAutoEncoder.from_config(cfg=cfg.sae)
     if cfg.finetuning:
@@ -87,12 +87,12 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
         tokenizer=hf_tokenizer,
         dtype=cfg.lm.dtype,
     )
-    model.offload_params_after(cfg.act_store.hook_points[0], torch.tensor([[0]], device=cfg.lm.device))
+    model.offload_params_after(
+        cfg.act_store.hook_points[0], torch.tensor([[0]], device=cfg.lm.device)
+    )
     model.eval()
     activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store)
-
-
     if cfg.wandb.log_to_wandb and is_master():
         wandb_config: dict = {
             **asdict(cfg),
@@ -108,7 +108,7 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
             entity=cfg.wandb.wandb_entity,
         )
         with open(
-            os.path.join(cfg.exp_result_dir, cfg.exp_name, "train_wandb_id.txt"), "w"
+            os.path.join(cfg.exp_result_path, "train_wandb_id.txt"), "w"
         ) as f:
             f.write(wandb_run.id)
         wandb.watch(sae, log="all")
@@ -128,8 +128,8 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
 
 
 def language_model_sae_prune_runner(cfg: LanguageModelSAEPruningConfig):
-    cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
-    cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))
+    cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_path))
+    cfg.lm.save_lm_config(os.path.join(cfg.exp_result_path))
     sae = SparseAutoEncoder.from_config(cfg=cfg.sae)
     hf_model = AutoModelForCausalLM.from_pretrained(
         (
@@ -177,7 +177,7 @@ def language_model_sae_prune_runner(cfg: LanguageModelSAEPruningConfig):
             entity=cfg.wandb.wandb_entity,
         )
         with open(
-            os.path.join(cfg.exp_result_dir, cfg.exp_name, "prune_wandb_id.txt"), "w"
+            os.path.join(cfg.exp_result_path, "prune_wandb_id.txt"), "w"
         ) as f:
             f.write(wandb_run.id)
 
@@ -229,7 +229,9 @@ def language_model_sae_eval_runner(cfg: LanguageModelSAERunnerConfig):
         tokenizer=hf_tokenizer,
         dtype=cfg.lm.dtype,
     )
-    model.offload_params_after(cfg.act_store.hook_points[0], torch.tensor([[0]], device=cfg.lm.device))
+    model.offload_params_after(
+        cfg.act_store.hook_points[0], torch.tensor([[0]], device=cfg.lm.device)
+    )
     model.eval()
     activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store)
 
@@ -248,7 +250,7 @@ def language_model_sae_eval_runner(cfg: LanguageModelSAERunnerConfig):
             entity=cfg.wandb.wandb_entity,
        )
         with open(
-            os.path.join(cfg.exp_result_dir, cfg.exp_name, "eval_wandb_id.txt"), "w"
+            os.path.join(cfg.exp_result_path, "eval_wandb_id.txt"), "w"
         ) as f:
             f.write(wandb_run.id)
 
@@ -309,14 +311,12 @@ def sample_feature_activations_runner(cfg: LanguageModelSAEAnalysisConfig):
         }
         if cfg.sae.use_glu_encoder:
             plan["encoder_glu"] = ColwiseParallel(output_layouts=Replicate())
-        sae = parallelize_module(sae, device_mesh=sae.device_mesh["tp"], parallelize_plan=plan) # type: ignore
+        sae = parallelize_module(sae, device_mesh=sae.device_mesh["tp"], parallelize_plan=plan)  # type: ignore
         sae.parallelize_plan = plan
 
     sae.decoder.weight = None  # type: ignore[assignment]
     torch.cuda.empty_cache()
 
-
-
     hf_model = AutoModelForCausalLM.from_pretrained(
         (
             cfg.lm.model_name
@@ -348,7 +348,9 @@ def sample_feature_activations_runner(cfg: LanguageModelSAEAnalysisConfig):
     model.eval()
     client = MongoClient(cfg.mongo.mongo_uri, cfg.mongo.mongo_db)
     if is_master():
-        client.create_dictionary(cfg.exp_name, cfg.sae.d_sae, cfg.exp_series)
+        client.create_dictionary(
+            cfg.exp_name, cfg.exp_result_path, cfg.sae.d_sae, cfg.exp_series
+        )
 
     for chunk_id in range(cfg.n_sae_chunks):
         activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store)
diff --git a/src/lm_saes/sae_training.py b/src/lm_saes/sae_training.py
index 93045de..5629067 100644
--- a/src/lm_saes/sae_training.py
+++ b/src/lm_saes/sae_training.py
@@ -318,8 +318,7 @@ def train_sae(
         ):
             # Save the model and optimizer state
             path = os.path.join(
-                cfg.exp_result_dir,
-                cfg.exp_name,
+                cfg.exp_result_path,
                 "checkpoints",
                 f"{n_training_steps}.safetensors",
             )
@@ -345,7 +344,7 @@ def train_sae(
     if not cfg.sae.sparsity_include_decoder_norm:
         sae.set_decoder_norm_to_fixed_norm(1)
     path = os.path.join(
-        cfg.exp_result_dir, cfg.exp_name, "checkpoints", "final.safetensors"
+        cfg.exp_result_path, "checkpoints", "final.safetensors"
     )
     sae.save_pretrained(path)
 
@@ -447,7 +446,7 @@ def prune_sae(
     print("Total pruned features:", (sae.feature_act_mask == 0).sum().item())
 
     path = os.path.join(
-        cfg.exp_result_dir, cfg.exp_name, "checkpoints", "pruned.safetensors"
+        cfg.exp_result_path, "checkpoints", "pruned.safetensors"
     )
     sae.save_pretrained(path)
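For reference, the `server/app.py` change above resolves a dictionary's location through the new `path` field written by `create_dictionary`, falling back to the legacy `{result_dir}/{dictionary_name}` layout when no record is found. Below is a minimal sketch of that lookup pattern; the `lm_saes.database` import location and the MongoDB connection values are assumptions for illustration, not part of this diff.

```python
from lm_saes.database import MongoClient  # assumed import location of the Mongo-backed client

# Hypothetical connection values; substitute your own MongoDB URI and database name.
client = MongoClient("mongodb://localhost:27017", "lm_saes")


def resolve_dictionary_path(name: str, series: str | None = None, result_dir: str = "results") -> str:
    # Prefer the path stored on the dictionary record (written by the new create_dictionary).
    path = client.get_dictionary_path(name, dictionary_series=series)
    # get_dictionary_path returns None when no matching record exists, so fall back
    # to the legacy "<result_dir>/<name>" layout used before this change.
    return path if path is not None else f"{result_dir}/{name}"
```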