diff --git a/src/lm_saes/config.py b/src/lm_saes/config.py index 75b50c3..82e6769 100644 --- a/src/lm_saes/config.py +++ b/src/lm_saes/config.py @@ -324,7 +324,7 @@ class LanguageModelSAETrainingConfig(LanguageModelSAERunnerConfig): lr_warm_up_steps: int | float = 0.1 lr_cool_down_steps: int | float = 0.1 train_batch_size: int = 4096 - clip_grad_norm: float = 0.0 + clip_grad_value: float = 0.0 remove_gradient_parallel_to_decoder_directions: bool = False finetuning: bool = False diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py index df7f8a7..d238422 100644 --- a/src/lm_saes/sae.py +++ b/src/lm_saes/sae.py @@ -515,7 +515,7 @@ def transform_to_unit_decoder_norm(self): decoder_norm = self.decoder_norm() # (d_sae,) self.encoder.weight.data = self.encoder.weight.data * decoder_norm[:, None] - self.decoder.weight.data = self.decoder.weight.data.T / decoder_norm + self.decoder.weight.data = self.decoder.weight.data / decoder_norm self.encoder.bias.data = self.encoder.bias.data * decoder_norm diff --git a/src/lm_saes/sae_training.py b/src/lm_saes/sae_training.py index b55746d..dc583ac 100644 --- a/src/lm_saes/sae_training.py +++ b/src/lm_saes/sae_training.py @@ -145,8 +145,8 @@ def train_sae( if cfg.finetuning: loss = loss_data["l_rec"].mean() loss.backward() - if cfg.clip_grad_norm > 0: - torch.nn.utils.clip_grad_norm_(sae.parameters(), cfg.clip_grad_norm) + if cfg.clip_grad_value > 0: + torch.nn.utils.clip_grad_value_(sae.parameters(), cfg.clip_grad_value) if cfg.remove_gradient_parallel_to_decoder_directions: sae.remove_gradient_parallel_to_decoder_directions() optimizer.step()