diff --git a/src/lm_saes/activation/token_source.py b/src/lm_saes/activation/token_source.py
index 3b15c559..65e9681b 100644
--- a/src/lm_saes/activation/token_source.py
+++ b/src/lm_saes/activation/token_source.py
@@ -116,13 +116,13 @@ def _process_dataset(dataset_path: str, cfg: TextDatasetConfig):
         if dist.is_initialized():
             shard_id = dist.get_rank()
             shard = dataset.shard(
-                num_shards=dist.get_world_size(), index=shard_id
+                num_shards=dist.get_world_size(), index=shard_id, contiguous=True
             )
         else:
             shard = dataset
 
-        dataloader = DataLoader(shard, batch_size=cfg.store_batch_size)
+        dataloader = DataLoader(shard, batch_size=cfg.store_batch_size, num_workers=4, prefetch_factor=4, pin_memory=True)
 
         return dataloader
 
     @staticmethod
diff --git a/src/lm_saes/config.py b/src/lm_saes/config.py
index 82e6769a..75b50c3d 100644
--- a/src/lm_saes/config.py
+++ b/src/lm_saes/config.py
@@ -324,7 +324,7 @@ class LanguageModelSAETrainingConfig(LanguageModelSAERunnerConfig):
     lr_warm_up_steps: int | float = 0.1
     lr_cool_down_steps: int | float = 0.1
     train_batch_size: int = 4096
-    clip_grad_value: float = 0.0
+    clip_grad_norm: float = 0.0
     remove_gradient_parallel_to_decoder_directions: bool = False
 
     finetuning: bool = False
diff --git a/src/lm_saes/runner.py b/src/lm_saes/runner.py
index 395d3fcb..eb5d4e8c 100644
--- a/src/lm_saes/runner.py
+++ b/src/lm_saes/runner.py
@@ -32,6 +32,19 @@
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 from lm_saes.utils.misc import is_master
 
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    loss_parallel,
+)
+from torch.distributed._tensor import (
+    DTensor,
+    Shard,
+    Replicate,
+    distribute_module,
+    distribute_tensor,
+)
+
 def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
     if is_master():
@@ -77,24 +90,7 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
     model.eval()
     activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store)
 
-    if (
-        cfg.sae.norm_activation == "dataset-wise" and cfg.sae.dataset_average_activation_norm is None
-        or cfg.sae.init_decoder_norm is None
-    ):
-        assert not cfg.finetuning
-        sae = SparseAutoEncoder.from_initialization_searching(
-            activation_store=activation_store,
-            cfg=cfg,
-        )
-    else:
-        sae = SparseAutoEncoder.from_config(cfg=cfg.sae)
-
-    if cfg.finetuning:
-        # Fine-tune SAE with frozen encoder weights and bias
-        sae.train_finetune_for_suppression_parameters()
-    cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
-    cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))
 
 
     if cfg.wandb.log_to_wandb and is_master():
         wandb_config: dict = {
@@ -304,6 +300,20 @@ def activation_generation_runner(cfg: ActivationGenerationConfig):
 
 def sample_feature_activations_runner(cfg: LanguageModelSAEAnalysisConfig):
     sae = SparseAutoEncoder.from_config(cfg=cfg.sae)
+    if cfg.sae.tp_size > 1:
+        plan = {
+            "encoder": ColwiseParallel(output_layouts=Replicate()),
+        }
+        if cfg.sae.use_glu_encoder:
+            plan["encoder_glu"] = ColwiseParallel(output_layouts=Replicate())
+        sae = parallelize_module(sae, device_mesh=sae.device_mesh["tp"], parallelize_plan=plan)  # type: ignore
+        sae.parallelize_plan = plan
+
+    sae.decoder.weight = None  # type: ignore[assignment]
+    torch.cuda.empty_cache()
+
+
+
     hf_model = AutoModelForCausalLM.from_pretrained(
         (
             cfg.lm.model_name
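Note on the runner.py hunk above: the analysis runner only needs encoder activations, so the plan shards encoder.weight column-wise across the tensor-parallel mesh and re-replicates its output, while the decoder weight is dropped to free memory. Below is a minimal, self-contained sketch of the same parallelize_module pattern, assuming a torchrun launch with one process per GPU; the ToySAE module, its dimensions, and the 1-D mesh are illustrative stand-ins, not code from this repository.

# Sketch: column-wise tensor parallelism for an SAE-style encoder.
# Run with: torchrun --nproc-per-node=<num_gpus> tp_sketch.py
import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import Replicate
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module


class ToySAE(nn.Module):
    """Illustrative stand-in for the SAE: a wide encoder and a decoder."""

    def __init__(self, d_model: int = 64, d_sae: int = 256):
        super().__init__()
        self.encoder = nn.Linear(d_model, d_sae)  # weight gets sharded along d_sae
        self.decoder = nn.Linear(d_sae, d_model)  # left unsharded in this sketch

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decoder(torch.relu(self.encoder(x)))


def main() -> None:
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    # 1-D device mesh over all ranks; init_device_mesh sets up the process group.
    mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))
    model = ToySAE().cuda()
    # ColwiseParallel shards encoder.weight across ranks; output_layouts=Replicate()
    # all-gathers the encoder output so every rank sees full-width feature
    # activations, analogous to the plan in sample_feature_activations_runner.
    plan = {"encoder": ColwiseParallel(output_layouts=Replicate())}
    model = parallelize_module(model, device_mesh=mesh, parallelize_plan=plan)
    out = model(torch.randn(8, 64, device="cuda"))
    print(f"rank {os.environ['RANK']}: output shape {tuple(out.shape)}")


if __name__ == "__main__":
    main()

Replicating the encoder output keeps downstream feature-selection code unchanged, at the cost of an all-gather per forward pass.
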
diff --git a/src/lm_saes/sae_training.py b/src/lm_saes/sae_training.py
index dc583ac9..b5c9d2ec 100644
--- a/src/lm_saes/sae_training.py
+++ b/src/lm_saes/sae_training.py
@@ -145,8 +145,9 @@ def train_sae(
         if cfg.finetuning:
             loss = loss_data["l_rec"].mean()
         loss.backward()
-        if cfg.clip_grad_value > 0:
-            torch.nn.utils.clip_grad_value_(sae.parameters(), cfg.clip_grad_value)
+        grad_norm = torch.tensor([0.0], device=cfg.sae.device)
+        if cfg.clip_grad_norm > 0:
+            grad_norm = torch.nn.utils.clip_grad_norm_(sae.parameters(), cfg.clip_grad_norm)
         if cfg.remove_gradient_parallel_to_decoder_directions:
             sae.remove_gradient_parallel_to_decoder_directions()
         optimizer.step()
@@ -171,13 +172,13 @@ def train_sae(
         if cfg.wandb.log_to_wandb and (is_master()):
             feature_sparsity = act_freq_scores / n_frac_active_tokens
             log_feature_sparsity = torch.log10(feature_sparsity + 1e-10)
-            wandb_histogram = wandb.Histogram(
-                log_feature_sparsity.detach().cpu().float().numpy()
-            )
+            # wandb_histogram = wandb.Histogram(
+            #     log_feature_sparsity.detach().cpu().float().numpy()
+            # )
             wandb.log(
                 {
                     "metrics/mean_log10_feature_sparsity": log_feature_sparsity.mean().item(),
-                    "plots/feature_density_line_chart": wandb_histogram,
+                    # "plots/feature_density_line_chart": wandb_histogram,
                     "sparsity/below_1e-5": (feature_sparsity < 1e-5)
                     .sum()
                     .item(),
@@ -285,8 +286,9 @@ def train_sae(
                     # norm
                     "metrics/decoder_norm": decoder_norm.item(),
                     "metrics/encoder_norm": encoder_norm.item(),
-                    "metrics/decoder_bias_mean": sae.decoder.bias.mean().item() if sae.cfg.use_decoder_bias else 0,
-                    "metrics/enocder_bias_mean": sae.encoder.bias.mean().item(),
+                    "metrics/decoder_bias_norm": sae.decoder.bias.norm().item() if sae.cfg.use_decoder_bias else 0,
+                    "metrics/encoder_bias_norm": sae.encoder.bias.norm().item(),
+                    "metrics/gradients_norm": grad_norm.item(),
                     # sparsity
                     "sparsity/l1_coefficient": sae.current_l1_coefficient,
                     "sparsity/mean_passes_since_fired": n_forward_passes_since_fired.mean().item(),
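
Note on the sae_training.py hunks above: unlike clip_grad_value_, clip_grad_norm_ returns the total gradient norm measured before clipping, which is what the new metrics/gradients_norm entry logs; when clipping is disabled the placeholder zero tensor is logged instead. A minimal standalone sketch of that pattern (the toy linear model and the 1.0 threshold are illustrative, not values from this repository):

# Sketch: clip_grad_norm_ returns the pre-clipping global gradient norm,
# so the same value can double as a logging metric.
import torch

model = torch.nn.Linear(16, 16)
loss = model(torch.randn(4, 16)).pow(2).mean()
loss.backward()

clip_grad_norm = 1.0  # illustrative threshold; 0.0 would disable clipping
grad_norm = torch.tensor([0.0])  # placeholder logged when clipping is disabled
if clip_grad_norm > 0:
    # Gradients are rescaled in place only if their global norm exceeds the
    # threshold; the returned tensor is the norm measured before any rescaling.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)

print({"metrics/gradients_norm": grad_norm.item()})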