diff --git a/examples/configuration/analyze.toml b/examples/configuration/analyze.toml
index eec1c74..09319c9 100644
--- a/examples/configuration/analyze.toml
+++ b/examples/configuration/analyze.toml
@@ -7,7 +7,7 @@
 dtype = "torch.float32"
 exp_name = "L3M"
 exp_series = "default"
-exp_result_dir = "results"
+exp_result_path = "results/L3M"
 
 [subsample]
 "top_activations" = { "proportion" = 1.0, "n_samples" = 80 }
diff --git a/examples/configuration/prune.toml b/examples/configuration/prune.toml
index f51991b..0c55276 100644
--- a/examples/configuration/prune.toml
+++ b/examples/configuration/prune.toml
@@ -5,7 +5,7 @@
 dtype = "torch.float32"
 exp_name = "L3M"
 exp_series = "default"
-exp_result_dir = "results"
+exp_result_path = "results/L3M"
 
 total_training_tokens = 10_000_000
 train_batch_size = 4096
diff --git a/examples/configuration/train.toml b/examples/configuration/train.toml
index dec963a..b22c9c8 100644
--- a/examples/configuration/train.toml
+++ b/examples/configuration/train.toml
@@ -1,6 +1,6 @@
 use_ddp = false
 exp_name = "L3M"
-exp_result_dir = "results"
+exp_result_path = "results/L3M"
 device = "cuda"
 seed = 42
 dtype = "torch.float32"
diff --git a/examples/programmatic/analyze.py b/examples/programmatic/analyze.py
index d60f259..cac1eaa 100644
--- a/examples/programmatic/analyze.py
+++ b/examples/programmatic/analyze.py
@@ -41,7 +41,7 @@
 
     exp_name = "L3M",
    exp_series = "default",
-    exp_result_dir = "results",
+    exp_result_path = "results/L3M",
 ))
 
 sample_feature_activations_runner(cfg)
\ No newline at end of file
diff --git a/examples/programmatic/train.py b/examples/programmatic/train.py
index 0f433a6..80e60ce 100644
--- a/examples/programmatic/train.py
+++ b/examples/programmatic/train.py
@@ -69,7 +69,7 @@
 
    exp_name = f"test", # The experiment name. Would be used for creating exp folder (which may contain checkpoints and analysis results) and setting wandb run name.
exp_series = "test", - exp_result_dir = "results" + exp_result_path = "results/test" )) sparse_autoencoder = language_model_sae_runner(cfg) \ No newline at end of file diff --git a/server/app.py b/server/app.py index f9239eb..8562c81 100644 --- a/server/app.py +++ b/server/app.py @@ -46,7 +46,9 @@ def get_model(dictionary_name: str) -> HookedTransformer: - path = client.get_dictionary(dictionary_name, dictionary_series=dictionary_series)['path'] or f"{result_dir}/{dictionary_name}" + path = client.get_dictionary_path(dictionary_name, dictionary_series=dictionary_series) + if path is "": + path = f"{result_dir}/{dictionary_name}" cfg = LanguageModelConfig.from_pretrained_sae(path) if (cfg.model_name, cfg.model_from_pretrained_path) not in lm_cache: hf_model = AutoModelForCausalLM.from_pretrained( diff --git a/src/lm_saes/analysis/sample_feature_activations.py b/src/lm_saes/analysis/sample_feature_activations.py index d8c8a13..cf506ec 100644 --- a/src/lm_saes/analysis/sample_feature_activations.py +++ b/src/lm_saes/analysis/sample_feature_activations.py @@ -1,6 +1,7 @@ import os from typing import cast +from torch.distributed._tensor import DTensor from tqdm import tqdm import torch @@ -16,6 +17,8 @@ from lm_saes.activation.activation_store import ActivationStore from lm_saes.utils.misc import print_once from lm_saes.utils.tensor_dict import concat_dict_of_tensor, sort_dict_of_tensor +import torch.distributed as dist + @torch.no_grad() def sample_feature_activations( @@ -28,10 +31,14 @@ def sample_feature_activations( ): if sae.cfg.ddp_size > 1: raise ValueError("Sampling feature activations does not support DDP yet") - assert cfg.sae.d_sae is not None # Make mypy happy + assert cfg.sae.d_sae is not None # Make mypy happy total_analyzing_tokens = cfg.total_analyzing_tokens - total_analyzing_steps = total_analyzing_tokens // cfg.act_store.dataset.store_batch_size // cfg.act_store.dataset.context_size + total_analyzing_steps = ( + total_analyzing_tokens + // cfg.act_store.dataset.store_batch_size + // cfg.act_store.dataset.context_size + ) print_once(f"Total Analyzing Tokens: {total_analyzing_tokens}") print_once(f"Total Analyzing Steps: {total_analyzing_steps}") @@ -41,19 +48,43 @@ def sample_feature_activations( sae.eval() - pbar = tqdm(total=total_analyzing_tokens, desc=f"Sampling activations of chunk {sae_chunk_id} of {n_sae_chunks}", smoothing=0.01) + pbar = tqdm( + total=total_analyzing_tokens, + desc=f"Sampling activations of chunk {sae_chunk_id} of {n_sae_chunks}", + smoothing=0.01, + ) d_sae = cfg.sae.d_sae // n_sae_chunks - start_index = sae_chunk_id * d_sae - end_index = (sae_chunk_id + 1) * d_sae - - sample_result = {k: { - "elt": torch.empty((0, d_sae), dtype=cfg.sae.dtype, device=cfg.sae.device), - "feature_acts": torch.empty((0, d_sae, cfg.act_store.dataset.context_size), dtype=cfg.sae.dtype, device=cfg.sae.device), - "contexts": torch.empty((0, d_sae, cfg.act_store.dataset.context_size), dtype=torch.int32, device=cfg.sae.device), - } for k in cfg.subsample.keys()} + assert ( + d_sae // cfg.sae.tp_size * cfg.sae.tp_size == d_sae + ), "d_sae must be divisible by tp_size" + d_sae //= cfg.sae.tp_size + + rank = dist.get_rank() if cfg.sae.tp_size > 1 else 0 + start_index = sae_chunk_id * d_sae * cfg.sae.tp_size + d_sae * rank + end_index = sae_chunk_id * d_sae * cfg.sae.tp_size + d_sae * (rank + 1) + + sample_result = { + k: { + "elt": torch.empty((0, d_sae), dtype=cfg.sae.dtype, device=cfg.sae.device), + "feature_acts": torch.empty( + (0, d_sae, cfg.act_store.dataset.context_size), 
+                dtype=cfg.sae.dtype,
+                device=cfg.sae.device,
+            ),
+            "contexts": torch.empty(
+                (0, d_sae, cfg.act_store.dataset.context_size),
+                dtype=torch.int32,
+                device=cfg.sae.device,
+            ),
+        }
+        for k in cfg.subsample.keys()
+    }
     act_times = torch.zeros((d_sae,), dtype=torch.long, device=cfg.sae.device)
-    feature_acts_all = [torch.empty((0,), dtype=cfg.sae.dtype, device=cfg.sae.device) for _ in range(d_sae)]
+    feature_acts_all = [
+        torch.empty((0,), dtype=cfg.sae.dtype, device=cfg.sae.device)
+        for _ in range(d_sae)
+    ]
     max_feature_acts = torch.zeros((d_sae,), dtype=cfg.sae.dtype, device=cfg.sae.device)
 
     while n_training_tokens < total_analyzing_tokens:
@@ -62,63 +93,110 @@
         if batch is None:
             raise ValueError("Not enough tokens to sample")
 
-        _, cache = model.run_with_cache_until(batch, names_filter=[cfg.sae.hook_point_in, cfg.sae.hook_point_out], until=cfg.sae.hook_point_out)
-        activation_in, activation_out = cache[cfg.sae.hook_point_in], cache[cfg.sae.hook_point_out]
+        _, cache = model.run_with_cache_until(
+            batch,
+            names_filter=[cfg.sae.hook_point_in, cfg.sae.hook_point_out],
+            until=cfg.sae.hook_point_out,
+        )
+        activation_in, activation_out = (
+            cache[cfg.sae.hook_point_in],
+            cache[cfg.sae.hook_point_out],
+        )
 
         filter_mask = torch.logical_or(
             batch.eq(model.tokenizer.eos_token_id),
-            batch.eq(model.tokenizer.pad_token_id)
+            batch.eq(model.tokenizer.pad_token_id),
         )
         filter_mask = torch.logical_or(
-            filter_mask,
-            batch.eq(model.tokenizer.bos_token_id)
+            filter_mask, batch.eq(model.tokenizer.bos_token_id)
         )
 
-        feature_acts = sae.encode(activation_in, label=activation_out)[..., start_index: end_index]
+        feature_acts = sae.encode(activation_in, label=activation_out)[
+            ..., start_index:end_index
+        ]
+        if isinstance(feature_acts, DTensor):
+            feature_acts = feature_acts.to_local()
 
         feature_acts[filter_mask] = 0
-
         act_times += feature_acts.gt(0.0).sum(dim=[0, 1])
 
         for name in cfg.subsample.keys():
             if cfg.enable_sampling:
-                weights = feature_acts.clamp(min=0.0).pow(cfg.sample_weight_exponent).max(dim=1).values
-                elt = torch.rand(batch.size(0), d_sae, device=cfg.sae.device, dtype=cfg.sae.dtype).log() / weights
+                weights = (
+                    feature_acts.clamp(min=0.0)
+                    .pow(cfg.sample_weight_exponent)
+                    .max(dim=1)
+                    .values
+                )
+                elt = (
+                    torch.rand(
+                        batch.size(0), d_sae, device=cfg.sae.device, dtype=cfg.sae.dtype
+                    ).log()
+                    / weights
+                )
                 elt[weights == 0.0] = -torch.inf
             else:
                 elt = feature_acts.clamp(min=0.0).max(dim=1).values
-            elt[feature_acts.max(dim=1).values > max_feature_acts.unsqueeze(0) * cfg.subsample[name]["proportion"]] = -torch.inf
+            elt[
+                feature_acts.max(dim=1).values
+                > max_feature_acts.unsqueeze(0) * cfg.subsample[name]["proportion"]
+            ] = -torch.inf
 
-            if sample_result[name]["elt"].size(0) > 0 and (elt.max(dim=0).values <= sample_result[name]["elt"][-1]).all():
+            if (
+                sample_result[name]["elt"].size(0) > 0
+                and (elt.max(dim=0).values <= sample_result[name]["elt"][-1]).all()
+            ):
                 continue
 
             sample_result[name] = concat_dict_of_tensor(
                 sample_result[name],
                 {
                     "elt": elt,
-                    "feature_acts": rearrange(feature_acts, 'batch_size context_size d_sae -> batch_size d_sae context_size'),
-                    "contexts": repeat(batch.to(torch.int32), 'batch_size context_size -> batch_size d_sae context_size', d_sae=d_sae),
+                    "feature_acts": rearrange(
+                        feature_acts,
+                        "batch_size context_size d_sae -> batch_size d_sae context_size",
+                    ),
+                    "contexts": repeat(
+                        batch.to(torch.int32),
+                        "batch_size context_size -> batch_size d_sae context_size",
+                        d_sae=d_sae,
+                    ),
                 },
                 dim=0,
             )
-            sample_result[name] = sort_dict_of_tensor(sample_result[name], sort_dim=0, sort_key="elt", descending=True)
+            sample_result[name] = sort_dict_of_tensor(
+                sample_result[name], sort_dim=0, sort_key="elt", descending=True
+            )
             sample_result[name] = {
-                k: v[:cfg.subsample[name]["n_samples"]] for k, v in sample_result[name].items()
+                k: v[: cfg.subsample[name]["n_samples"]]
+                for k, v in sample_result[name].items()
             }
 
-        # Update feature activation histogram every 10 steps
         if n_training_steps % 50 == 49:
-            feature_acts_cur = rearrange(feature_acts, 'batch_size context_size d_sae -> d_sae (batch_size context_size)')
+            feature_acts_cur = rearrange(
+                feature_acts,
+                "batch_size context_size d_sae -> d_sae (batch_size context_size)",
+            )
             for i in range(d_sae):
-                feature_acts_all[i] = torch.cat([feature_acts_all[i], feature_acts_cur[i][feature_acts_cur[i] > 0.0]], dim=0)
-
-        max_feature_acts = torch.max(max_feature_acts, feature_acts.max(dim=0).values.max(dim=0).values)
+                feature_acts_all[i] = torch.cat(
+                    [
+                        feature_acts_all[i],
+                        feature_acts_cur[i][feature_acts_cur[i] > 0.0],
+                    ],
+                    dim=0,
+                )
+
+        max_feature_acts = torch.max(
+            max_feature_acts, feature_acts.max(dim=0).values.max(dim=0).values
+        )
 
-        n_tokens_current = torch.tensor(batch.size(0) * batch.size(1), device=cfg.sae.device, dtype=torch.int)
+        n_tokens_current = torch.tensor(
+            batch.size(0) * batch.size(1), device=cfg.sae.device, dtype=torch.int
+        )
         n_training_tokens += cast(int, n_tokens_current.item())
         n_training_steps += 1
@@ -126,12 +204,18 @@
 
     pbar.close()
 
-    sample_result = {k1: {
-        k2: rearrange(v2, 'n_samples d_sae ... -> d_sae n_samples ...') for k2, v2 in v1.items()
-    } for k1, v1 in sample_result.items()}
+    sample_result = {
+        k1: {
-> d_sae n_samples ...") + for k2, v2 in v1.items() + } + for k1, v1 in sample_result.items() + } result = { - "index": torch.arange(start_index, end_index, device=cfg.sae.device, dtype=torch.int32), + "index": torch.arange( + start_index, end_index, device=cfg.sae.device, dtype=torch.int32 + ), "act_times": act_times, "feature_acts_all": feature_acts_all, "max_feature_acts": max_feature_acts, @@ -140,8 +224,9 @@ def sample_feature_activations( "name": k, "feature_acts": v["feature_acts"], "contexts": v["contexts"], - } for k, v in sample_result.items() + } + for k, v in sample_result.items() ], } - return result \ No newline at end of file + return result diff --git a/src/lm_saes/config.py b/src/lm_saes/config.py index 25311fb..62fdded 100644 --- a/src/lm_saes/config.py +++ b/src/lm_saes/config.py @@ -53,13 +53,12 @@ def __post_init__(self): class RunnerConfig(BaseConfig): exp_name: str = "test" exp_series: Optional[str] = None - exp_result_dir: str = "results" + exp_result_path: str = "results" def __post_init__(self): super().__post_init__() if is_master(): - os.makedirs(self.exp_result_dir, exist_ok=True) - os.makedirs(os.path.join(self.exp_result_dir, self.exp_name), exist_ok=True) + os.makedirs(self.exp_result_path, exist_ok=True) @dataclass(kw_only=True) @@ -257,12 +256,12 @@ def from_pretrained( @deprecated("Use from_pretrained and to_dict instead.") @staticmethod def get_hyperparameters( - exp_name: str, exp_result_dir: str, ckpt_name: str, strict_loading: bool = True + exp_result_path: str, ckpt_name: str, strict_loading: bool = True ) -> dict[str, Any]: - with open(os.path.join(exp_result_dir, exp_name, "hyperparams.json"), "r") as f: + with open(os.path.join(exp_result_path, "hyperparams.json"), "r") as f: hyperparams = json.load(f) hyperparams["sae_pretrained_name_or_path"] = os.path.join( - exp_result_dir, exp_name, "checkpoints", ckpt_name + exp_result_path, "checkpoints", ckpt_name ) hyperparams["strict_loading"] = strict_loading # Remove non-hyperparameters from the dict @@ -350,13 +349,13 @@ def __post_init__(self): super().__post_init__() if is_master(): - # if os.path.exists( - # os.path.join(self.exp_result_dir, self.exp_name, "checkpoints") - # ): - # raise ValueError( - # f"Checkpoints for experiment {self.exp_name} already exist. Consider changing the experiment name." - # ) - os.makedirs(os.path.join(self.exp_result_dir, self.exp_name, "checkpoints"), exist_ok=True) + if os.path.exists( + os.path.join(self.exp_result_path, "checkpoints") + ): + raise ValueError( + f"Checkpoints for experiment {self.exp_result_path} already exist. Consider changing the experiment name." 
+                )
+            os.makedirs(os.path.join(self.exp_result_path, "checkpoints"))
 
         self.effective_batch_size = self.train_batch_size * self.sae.ddp_size
         print_once(f"Effective batch size: {self.effective_batch_size}")
@@ -418,7 +417,7 @@ def __post_init__(self):
 
         if is_master():
             os.makedirs(
-                os.path.join(self.exp_result_dir, self.exp_name, "checkpoints"),
+                os.path.join(self.exp_result_path, "checkpoints"),
                 exist_ok=True,
             )
diff --git a/src/lm_saes/database.py b/src/lm_saes/database.py
index 0e68e98..02df8b8 100644
--- a/src/lm_saes/database.py
+++ b/src/lm_saes/database.py
@@ -58,8 +58,8 @@ def _remove_gridfs_objs(self, data):
         if isinstance(data, ObjectId) and self.fs.exists(data):
             self.fs.delete(data)
 
-    def create_dictionary(self, dictionary_name: str, n_features: int, dictionary_series: str | None = None):
-        dict_id = self.dictionary_collection.insert_one({'name': dictionary_name, 'n_features': n_features, 'series': dictionary_series}).inserted_id
+    def create_dictionary(self, dictionary_name: str, dictionary_path: str, n_features: int, dictionary_series: str | None = None):
+        dict_id = self.dictionary_collection.insert_one({'name': dictionary_name, 'n_features': n_features, 'series': dictionary_series, 'path': dictionary_path}).inserted_id
         self.feature_collection.insert_many([
             {
                 'dictionary_id': dict_id,
@@ -85,6 +85,7 @@ def update_feature(self, dictionary_name: str, feature_index: int, feature_data:
             self.feature_collection.update_one({'_id': feature['_id']}, {'$set': self._to_gridfs(feature_data)})
 
     def list_dictionaries(self, dictionary_series: str | None = None):
+        # return [{'name': d['name'], 'path': d['path']} for d in self.dictionary_collection.find({'series': dictionary_series} if dictionary_series is not None else {})]
         return [d['name'] for d in self.dictionary_collection.find({'series': dictionary_series} if dictionary_series is not None else {})]
 
     def get_dictionary(self, dictionary_name: str, dictionary_series: str | None = None):
@@ -211,3 +212,8 @@ def get_attn_head(self, dictionary_name: str, head_index: int, dictionary_series
         ]
         return self.attn_head_collection.aggregate(pipeline).next()
 
+    def get_dictionary_path(self, dictionary_name: str, dictionary_series: str | None = None):
+        dictionary = self.dictionary_collection.find_one({'name': dictionary_name, 'series': dictionary_series})
+        if dictionary is None:
+            return None
+        return dictionary.get('path')
\ No newline at end of file
diff --git a/src/lm_saes/runner.py b/src/lm_saes/runner.py
index 80febf2..6ffc2df 100644
--- a/src/lm_saes/runner.py
+++ b/src/lm_saes/runner.py
@@ -34,6 +34,7 @@
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
+    RowwiseParallel,
     parallelize_module,
     loss_parallel,
 )
@@ -97,8 +98,8 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
         sae.train_finetune_for_suppression_parameters()
 
     if is_master():
-        cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
-        cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))
+        cfg.sae.save_hyperparameters(cfg.exp_result_path)
+        cfg.lm.save_lm_config(cfg.exp_result_path)
 
     if cfg.wandb.log_to_wandb and is_master():
         wandb_config: dict = {
@@ -115,7 +116,7 @@
             entity=cfg.wandb.wandb_entity,
         )
         with open(
-            os.path.join(cfg.exp_result_dir, cfg.exp_name, "train_wandb_id.txt"), "w"
+            os.path.join(cfg.exp_result_path, "train_wandb_id.txt"), "w"
         ) as f:
             f.write(wandb_run.id)
         wandb.watch(sae, log="all")
@@ -135,8 +136,8 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
 
 
 def language_model_sae_prune_runner(cfg: LanguageModelSAEPruningConfig):
-    cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
-    cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))
+    cfg.sae.save_hyperparameters(cfg.exp_result_path)
+    cfg.lm.save_lm_config(cfg.exp_result_path)
     sae = SparseAutoEncoder.from_config(cfg=cfg.sae)
     hf_model = AutoModelForCausalLM.from_pretrained(
         (
@@ -184,7 +185,7 @@ def language_model_sae_prune_runner(cfg: LanguageModelSAEPruningConfig):
             entity=cfg.wandb.wandb_entity,
         )
         with open(
-            os.path.join(cfg.exp_result_dir, cfg.exp_name, "prune_wandb_id.txt"), "w"
+            os.path.join(cfg.exp_result_path, "prune_wandb_id.txt"), "w"
         ) as f:
             f.write(wandb_run.id)
@@ -236,6 +237,9 @@ def language_model_sae_eval_runner(cfg: LanguageModelSAERunnerConfig):
         tokenizer=hf_tokenizer,
         dtype=cfg.lm.dtype,
     )
+    model.offload_params_after(
+        cfg.act_store.hook_points[0], torch.tensor([[0]], device=cfg.lm.device)
+    )
     model.eval()
     activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store)
@@ -254,7 +258,7 @@
             entity=cfg.wandb.wandb_entity,
         )
         with open(
-            os.path.join(cfg.exp_result_dir, cfg.exp_name, "eval_wandb_id.txt"), "w"
+            os.path.join(cfg.exp_result_path, "eval_wandb_id.txt"), "w"
         ) as f:
             f.write(wandb_run.id)
@@ -311,17 +315,16 @@ def sample_feature_activations_runner(cfg: LanguageModelSAEAnalysisConfig):
     if cfg.sae.tp_size > 1:
         plan = {
             "encoder": ColwiseParallel(output_layouts=Replicate()),
+            "decoder": RowwiseParallel(output_layouts=Replicate()),
         }
         if cfg.sae.use_glu_encoder:
             plan["encoder_glu"] = ColwiseParallel(output_layouts=Replicate())
-        sae = parallelize_module(sae, device_mesh=sae.device_mesh["tp"], parallelize_plan=plan) # type: ignore
+        sae = parallelize_module(sae, device_mesh=sae.device_mesh["tp"], parallelize_plan=plan)  # type: ignore
         sae.parallelize_plan = plan
 
     sae.decoder.weight = None  # type: ignore[assignment]
     torch.cuda.empty_cache()
-
-
     hf_model = AutoModelForCausalLM.from_pretrained(
         (
             cfg.lm.model_name
@@ -351,16 +354,17 @@
         dtype=cfg.lm.dtype,
     )
     model.eval()
-
     client = MongoClient(cfg.mongo.mongo_uri, cfg.mongo.mongo_db)
-    client.create_dictionary(cfg.exp_name, cfg.sae.d_sae, cfg.exp_series)
+    if is_master():
+        client.create_dictionary(
+            cfg.exp_name, cfg.exp_result_path, cfg.sae.d_sae, cfg.exp_series
+        )
 
     for chunk_id in range(cfg.n_sae_chunks):
         activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store)
         result = sample_feature_activations(
             sae, model, activation_store, cfg, chunk_id, cfg.n_sae_chunks
         )
-
         for i in range(len(result["index"].cpu().numpy().tolist())):
             client.update_feature(
                 cfg.exp_name,
diff --git a/src/lm_saes/sae_training.py b/src/lm_saes/sae_training.py
index 93045de..5629067 100644
--- a/src/lm_saes/sae_training.py
+++ b/src/lm_saes/sae_training.py
@@ -318,8 +318,7 @@ def train_sae(
         ):
             # Save the model and optimizer state
             path = os.path.join(
-                cfg.exp_result_dir,
-                cfg.exp_name,
+                cfg.exp_result_path,
                 "checkpoints",
                 f"{n_training_steps}.safetensors",
             )
@@ -345,7 +344,7 @@
         if not cfg.sae.sparsity_include_decoder_norm:
             sae.set_decoder_norm_to_fixed_norm(1)
         path = os.path.join(
-            cfg.exp_result_dir, cfg.exp_name, "checkpoints", "final.safetensors"
+            cfg.exp_result_path, "checkpoints", "final.safetensors"
         )
         sae.save_pretrained(path)
@@ -447,7 +446,7 @@ def prune_sae(
     print("Total pruned features:", (sae.feature_act_mask == 0).sum().item())
 
     path = os.path.join(
-        cfg.exp_result_dir, cfg.exp_name, "checkpoints", "pruned.safetensors"
+        cfg.exp_result_path, "checkpoints", "pruned.safetensors"
     )
     sae.save_pretrained(path)