Merge pull request #37 from OpenMOSS/zfhe
Hzfinfdu authored Jul 25, 2024
2 parents 3019a37 + ecb8fa5 commit 3d0fd81
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/lm_saes/config.py
@@ -324,7 +324,7 @@ class LanguageModelSAETrainingConfig(LanguageModelSAERunnerConfig):
     lr_warm_up_steps: int | float = 0.1
     lr_cool_down_steps: int | float = 0.1
     train_batch_size: int = 4096
-    clip_grad_norm: float = 0.0
+    clip_grad_value: float = 0.0
     remove_gradient_parallel_to_decoder_directions: bool = False

     finetuning: bool = False
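The rename clarifies which PyTorch clipping utility the option drives: norm clipping rescales all gradients jointly, while value clipping clamps each gradient element independently. A minimal sketch of the difference, using a stand-in linear model and an illustrative threshold of 1.0 (neither is taken from this repo):

    import torch

    model = torch.nn.Linear(16, 16)
    loss = model(torch.randn(4, 16)).pow(2).mean()
    loss.backward()

    # clip_grad_norm_ rescales ALL gradients together so their total norm
    # is at most max_norm; relative gradient directions are preserved.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # clip_grad_value_ clamps each gradient element independently into
    # [-clip_value, clip_value]; this is what the renamed option selects.
    torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)

In practice a trainer would apply one of the two, not both; they are shown back to back here only to contrast their behavior and signatures.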
2 changes: 1 addition & 1 deletion src/lm_saes/sae.py
@@ -515,7 +515,7 @@ def transform_to_unit_decoder_norm(self):

         decoder_norm = self.decoder_norm()  # (d_sae,)
         self.encoder.weight.data = self.encoder.weight.data * decoder_norm[:, None]
-        self.decoder.weight.data = self.decoder.weight.data.T / decoder_norm
+        self.decoder.weight.data = self.decoder.weight.data / decoder_norm

         self.encoder.bias.data = self.encoder.bias.data * decoder_norm
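The removed .T was a bug: with a decoder weight of shape (d_model, d_sae), dividing by the (d_sae,) vector of column norms already broadcasts correctly, whereas the transpose would assign a (d_sae, d_model) tensor into the parameter. A standalone sketch of the transform with a hypothetical stand-in SAE (the shapes mirror the diff, but d_model, d_sae, encoder, and decoder are assumed names, not the repo's classes):

    import torch

    d_model, d_sae = 16, 64
    encoder = torch.nn.Linear(d_model, d_sae)
    decoder = torch.nn.Linear(d_sae, d_model, bias=False)

    # One norm per latent: the column norms of the decoder, shape (d_sae,).
    decoder_norm = decoder.weight.data.norm(dim=0)

    # Fold each latent's decoder scale into the encoder so the overall
    # reconstruction is preserved (positive scales commute with ReLU-style
    # activations), then normalize the decoder columns to unit norm.
    encoder.weight.data = encoder.weight.data * decoder_norm[:, None]
    encoder.bias.data = encoder.bias.data * decoder_norm
    # (d_model, d_sae) / (d_sae,) broadcasts over columns -- no transpose needed.
    decoder.weight.data = decoder.weight.data / decoder_norm

    print(decoder.weight.data.norm(dim=0))  # ~1.0 for every column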
4 changes: 2 additions & 2 deletions src/lm_saes/sae_training.py
@@ -145,8 +145,8 @@ def train_sae(
         if cfg.finetuning:
             loss = loss_data["l_rec"].mean()
         loss.backward()
-        if cfg.clip_grad_norm > 0:
-            torch.nn.utils.clip_grad_norm_(sae.parameters(), cfg.clip_grad_norm)
+        if cfg.clip_grad_value > 0:
+            torch.nn.utils.clip_grad_value_(sae.parameters(), cfg.clip_grad_value)
         if cfg.remove_gradient_parallel_to_decoder_directions:
             sae.remove_gradient_parallel_to_decoder_directions()
         optimizer.step()
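This hunk applies the same rename where the option takes effect. A minimal training-step sketch of where clip_grad_value_ sits relative to backward() and step(); the sae, optimizer, and threshold below are stand-ins, not the repo's actual train_sae loop:

    import torch

    sae = torch.nn.Linear(16, 16)  # stand-in for the SAE module
    optimizer = torch.optim.Adam(sae.parameters(), lr=1e-3)
    clip_grad_value = 1.0          # stand-in for cfg.clip_grad_value

    x = torch.randn(32, 16)
    loss = (sae(x) - x).pow(2).mean()  # toy reconstruction loss

    optimizer.zero_grad()
    loss.backward()
    if clip_grad_value > 0:
        # Element-wise clamp into [-clip_grad_value, clip_grad_value],
        # matching the renamed config option.
        torch.nn.utils.clip_grad_value_(sae.parameters(), clip_grad_value)
    optimizer.step()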
