diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py index 416ae46..b86569b 100644 --- a/src/lm_saes/sae.py +++ b/src/lm_saes/sae.py @@ -515,7 +515,7 @@ def transform_to_unit_decoder_norm(self): decoder_norm = self.decoder_norm() # (d_sae,) self.encoder.weight.data = self.encoder.weight.data * decoder_norm[:, None] - self.decoder.weight.data = self.decoder.weight.data.T / decoder_norm + self.decoder.weight.data = self.decoder.weight.data / decoder_norm self.encoder.bias.data = self.encoder.bias.data * decoder_norm