From ef0a24a09924b401c187ce259862497cea0a42ed Mon Sep 17 00:00:00 2001 From: root <2247778946@qq.com> Date: Fri, 16 Aug 2024 09:09:26 +0000 Subject: [PATCH] fix(sae): support grid searching for best init --- src/lm_saes/sae.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py index 1f011ab..cfc2702 100644 --- a/src/lm_saes/sae.py +++ b/src/lm_saes/sae.py @@ -332,6 +332,7 @@ def compute_loss( | None ) = None, return_aux_data: bool = True, + during_init: bool = False, ) -> Union[ Float[torch.Tensor, "batch"], tuple[ @@ -373,7 +374,7 @@ def compute_loss( if self.cfg.sparsity_include_decoder_norm: l_l1 = torch.norm( - feature_acts_normed * self.decoder_norm(), + feature_acts_normed * self.decoder_norm(during_init=during_init), p=self.cfg.lp, dim=-1, ) @@ -669,7 +670,7 @@ def grid_search_best_init_norm(search_range: List[float]) -> float: test_sae.encoder.weight.data = ( test_sae.decoder.weight.data.T.clone().contiguous() ) - mse = test_sae.compute_loss(x=activation_in, label=activation_out)[1][0]["l_rec"].mean().item() # type: ignore + mse = test_sae.compute_loss(x=activation_in, label=activation_out, during_init=True)[1][0]["l_rec"].mean().item() # type: ignore losses[norm] = mse best_norm = min(losses, key=losses.get) # type: ignore return best_norm @@ -687,7 +688,7 @@ def grid_search_best_init_norm(search_range: List[float]) -> float: ) test_sae.set_decoder_norm_to_fixed_norm( - best_norm_fine_grained, force_exact=True + best_norm_fine_grained, force_exact=True, during_init=True ) test_sae.encoder.weight.data = ( test_sae.decoder.weight.data.T.clone().contiguous()