diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py
index 6c5abae..1630057 100644
--- a/src/lm_saes/sae.py
+++ b/src/lm_saes/sae.py
@@ -118,7 +118,9 @@ def initialize_parameters(self):
         if self.cfg.init_encoder_with_decoder_transpose:
             self.encoder.weight.data = self.decoder.weight.data.T.clone().contiguous()
         else:
-            self.set_encoder_norm_to_fixed_norm(self.cfg.init_encoder_norm, during_init=True)
+            self.set_encoder_norm_to_fixed_norm(
+                self.cfg.init_encoder_norm, during_init=True
+            )
 
     def train_base_parameters(self):
         """Set the base parameters to be trained."""
@@ -264,7 +266,7 @@ def encode(
 
         if self.cfg.use_decoder_bias and self.cfg.apply_decoder_bias_to_pre_encoder:
             x = (
-                x - self.decoder.bias.to_local() # type: ignore
+                x - self.decoder.bias.to_local()  # type: ignore
                 if self.cfg.tp_size > 1
                 else x - self.decoder.bias
             )
@@ -483,22 +485,23 @@ def set_decoder_norm_to_fixed_norm(
         if force_exact is None:
             force_exact = self.cfg.decoder_exactly_fixed_norm
-
         if self.cfg.tp_size > 1 and not during_init:
             decoder_norm = distribute_tensor(
-                decoder_norm, device_mesh=self.device_mesh["tp"], placements=[Replicate()]
+                decoder_norm,
+                device_mesh=self.device_mesh["tp"],
+                placements=[Shard(0)],
             )
         if force_exact:
-            self.decoder.weight.data = self.decoder.weight.data * value / decoder_norm
+            self.decoder.weight.data *= value / decoder_norm
         else:
             # Set the norm of the decoder to not exceed value
-            self.decoder.weight.data = (
-                self.decoder.weight.data * value / torch.clamp(decoder_norm, min=value)
-            )
+            self.decoder.weight.data *= value / torch.clamp(decoder_norm, min=value)
 
     @torch.no_grad()
-    def set_encoder_norm_to_fixed_norm(self, value: float | None = 1.0, during_init: bool = False):
+    def set_encoder_norm_to_fixed_norm(
+        self, value: float | None = 1.0, during_init: bool = False
+    ):
         if self.cfg.use_glu_encoder:
             raise NotImplementedError("GLU encoder not supported")
         if value is None:
@@ -509,43 +512,11 @@ def set_encoder_norm_to_fixed_norm(self, value: float | None = 1.0, during_init:
         encoder_norm = self.encoder_norm(keepdim=True, during_init=during_init)
         if self.cfg.tp_size > 1 and not during_init:
             encoder_norm = distribute_tensor(
-                encoder_norm, device_mesh=self.device_mesh["tp"], placements=[Replicate()]
-            )
-        self.encoder.weight.data = self.encoder.weight.data * value / encoder_norm
-
-    @torch.no_grad()
-    def transform_to_unit_decoder_norm(self):
-        """
-        If we include decoder norm in the sparsity loss, the final decoder norm is not guaranteed to be 1.
-        We make an equivalent transformation to the decoder to make it unit norm.
-        See https://transformer-circuits.pub/2024/april-update/index.html#training-saes
-        """
-        assert (
-            self.cfg.sparsity_include_decoder_norm
-        ), "Decoder norm is not included in the sparsity loss"
-        if self.cfg.use_glu_encoder:
-            raise NotImplementedError("GLU encoder not supported")
-
-        decoder_norm = self.decoder_norm()  # (d_sae,)
-        if self.cfg.tp_size > 1:
-            decoder_norm_en = distribute_tensor(
-                decoder_norm[:, None], device_mesh=self.device_mesh["tp"], placements=[Replicate()]
-            )
-            decoder_norm_de = distribute_tensor(
-                decoder_norm, device_mesh=self.device_mesh["tp"], placements=[Replicate()]
-            )
-            dencoder_norm_bias = distribute_tensor(
-                decoder_norm, device_mesh=self.device_mesh["tp"], placements=[Replicate()]
+                encoder_norm,
+                device_mesh=self.device_mesh["tp"],
+                placements=[Shard(0)],
             )
-        else:
-            decoder_norm_en = decoder_norm[:, None]
-            decoder_norm_de = decoder_norm
-            dencoder_norm_bias = decoder_norm
-
-        self.encoder.weight.data = self.encoder.weight.data * decoder_norm_en
-        self.decoder.weight.data = self.decoder.weight.data / decoder_norm_de
-
-        self.encoder.bias.data = self.encoder.bias.data * dencoder_norm_bias
+        self.encoder.weight.data *= (value / encoder_norm)
 
     @torch.no_grad()
     def remove_gradient_parallel_to_decoder_directions(self):
@@ -651,9 +622,7 @@ def from_initialization_searching(
         activation_store: ActivationStore,
         cfg: LanguageModelSAETrainingConfig,
     ):
-        test_batch = activation_store.next(
-            batch_size=cfg.train_batch_size
-        )
+        test_batch = activation_store.next(batch_size=cfg.train_batch_size)
         activation_in, activation_out = test_batch[cfg.sae.hook_point_in], test_batch[cfg.sae.hook_point_out]  # type: ignore
 
         if (
@@ -746,11 +715,27 @@ def save_pretrained(self, ckpt_path: str) -> None:
         if os.path.isdir(ckpt_path):
             ckpt_path = os.path.join(ckpt_path, "sae_weights.safetensors")
         state_dict = self.get_full_state_dict()
+
+        @torch.no_grad()
+        def transform_to_unit_decoder_norm(
+            state_dict: Dict[str, torch.Tensor]
+        ) -> Dict[str, torch.Tensor]:
+            decoder_norm = torch.norm(
+                state_dict["decoder.weight"], p=2, dim=0, keepdim=False
+            )
+            state_dict["decoder.weight"] = state_dict["decoder.weight"] / decoder_norm
+            state_dict["encoder.weight"] = (
+                state_dict["encoder.weight"] * decoder_norm[:, None]
+            )
+            state_dict["encoder.bias"] = state_dict["encoder.bias"] * decoder_norm
+            return state_dict
+
+        if self.cfg.sparsity_include_decoder_norm:
+            state_dict = transform_to_unit_decoder_norm(state_dict)
+
         if is_master():
             if ckpt_path.endswith(".safetensors"):
-                safe.save_file(
-                    state_dict, ckpt_path, {"version": version("lm-saes")}
-                )
+                safe.save_file(state_dict, ckpt_path, {"version": version("lm-saes")})
             elif ckpt_path.endswith(".pt"):
                 torch.save(
                     {"sae": state_dict, "version": version("lm-saes")}, ckpt_path
                 )
@@ -766,8 +751,8 @@ def decoder_norm(self, keepdim: bool = False, during_init: bool = False):
             return torch.norm(self.decoder.weight, p=2, dim=0, keepdim=keepdim)
         else:
             decoder_norm = torch.norm(
-                self.decoder.weight.to_local(), p=2, dim=0, keepdim=keepdim # type: ignore
-                )
+                self.decoder.weight.to_local(), p=2, dim=0, keepdim=keepdim  # type: ignore
+            )
             decoder_norm = DTensor.from_local(
                 decoder_norm,
                 device_mesh=self.device_mesh["tp"],
@@ -787,8 +772,8 @@ def encoder_norm(
             return torch.norm(self.encoder.weight, p=2, dim=1, keepdim=keepdim)
         else:
             encoder_norm = torch.norm(
-                self.encoder.weight.to_local(), p=2, dim=1, keepdim=keepdim # type: ignore
-                )
+                self.encoder.weight.to_local(), p=2, dim=1, keepdim=keepdim  # type: ignore
+            )
             encoder_norm = DTensor.from_local(
                 encoder_norm, device_mesh=self.device_mesh["tp"], placements=[Shard(0)]
             )
diff --git a/src/lm_saes/sae_training.py b/src/lm_saes/sae_training.py
index b5c9d2e..93045de 100644
--- a/src/lm_saes/sae_training.py
+++ b/src/lm_saes/sae_training.py
@@ -342,9 +342,7 @@ def train_sae(
     pbar.close()
 
     # Save the final model
-    if cfg.sae.sparsity_include_decoder_norm:
-        sae.transform_to_unit_decoder_norm()
-    else:
+    if not cfg.sae.sparsity_include_decoder_norm:
         sae.set_decoder_norm_to_fixed_norm(1)
     path = os.path.join(
         cfg.exp_result_dir, cfg.exp_name, "checkpoints", "final.safetensors"
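
Note on the relocated rescaling (not part of the patch itself): the diff deletes the in-place transform_to_unit_decoder_norm method from the SAE class and instead folds the per-feature decoder norms into the checkpoint inside save_pretrained, gated on sparsity_include_decoder_norm; train_sae therefore no longer mutates the model after training and only calls set_decoder_norm_to_fixed_norm(1) in the non-decoder-norm case. The sketch below is a standalone check, assuming a plain ReLU SAE with hypothetical tensor names and the shape convention used in the diff (encoder.weight of shape (d_sae, d_model), decoder.weight of shape (d_model, d_sae)), that this rescaling leaves reconstructions unchanged because ReLU commutes with a positive per-feature scale. It is an illustration, not code from the repository.

# Standalone sketch: verify that dividing decoder columns by their norms while
# multiplying encoder rows and the encoder bias by the same norms (the algebra of
# the nested transform_to_unit_decoder_norm in save_pretrained) preserves a ReLU
# SAE's reconstruction. All names and shapes here are assumptions for illustration.
import torch

torch.manual_seed(0)
torch.set_default_dtype(torch.float64)  # keep the numerical check tight

d_model, d_sae = 16, 64
x = torch.randn(8, d_model)

W_enc, b_enc = torch.randn(d_sae, d_model), torch.randn(d_sae)    # encoder.weight, encoder.bias
W_dec, b_dec = torch.randn(d_model, d_sae), torch.randn(d_model)  # decoder.weight, decoder.bias


def reconstruct(x, W_enc, b_enc, W_dec, b_dec):
    feats = torch.relu(x @ W_enc.T + b_enc)  # (batch, d_sae)
    return feats @ W_dec.T + b_dec           # (batch, d_model)


decoder_norm = torch.norm(W_dec, p=2, dim=0)  # (d_sae,) per-feature column norms
W_dec_unit = W_dec / decoder_norm                 # decoder.weight / decoder_norm
W_enc_scaled = W_enc * decoder_norm[:, None]      # encoder.weight * decoder_norm[:, None]
b_enc_scaled = b_enc * decoder_norm               # encoder.bias * decoder_norm

# ReLU is positively homogeneous: scaling feature i's pre-activation by decoder_norm[i]
# while shrinking decoder column i by the same factor cancels in the reconstruction.
before = reconstruct(x, W_enc, b_enc, W_dec, b_dec)
after = reconstruct(x, W_enc_scaled, b_enc_scaled, W_dec_unit, b_dec)
assert torch.allclose(before, after)
assert torch.allclose(torch.norm(W_dec_unit, p=2, dim=0), torch.ones(d_sae))

The same argument is why applying the transform once to the gathered full state dict at save time is equivalent to the removed method that rescaled encoder.weight, encoder.bias, and decoder.weight in place (tensor-parallel branches aside), while guaranteeing the saved decoder.weight has exactly unit column norm.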