fix(sae): support grid searching for best init
FlyingDutchman26 committed Aug 16, 2024
1 parent acfac55 commit ef0a24a
Showing 1 changed file with 4 additions and 3 deletions.
src/lm_saes/sae.py: 7 changes (4 additions, 3 deletions)
@@ -332,6 +332,7 @@ def compute_loss(
             | None
         ) = None,
         return_aux_data: bool = True,
+        during_init: bool = False,
     ) -> Union[
         Float[torch.Tensor, "batch"],
         tuple[
@@ -373,7 +374,7 @@ def compute_loss(
         if self.cfg.sparsity_include_decoder_norm:

             l_l1 = torch.norm(
-                feature_acts_normed * self.decoder_norm(),
+                feature_acts_normed * self.decoder_norm(during_init=during_init),
                 p=self.cfg.lp,
                 dim=-1,
             )
@@ -669,7 +670,7 @@ def grid_search_best_init_norm(search_range: List[float]) -> float:
                 test_sae.encoder.weight.data = (
                     test_sae.decoder.weight.data.T.clone().contiguous()
                 )
-                mse = test_sae.compute_loss(x=activation_in, label=activation_out)[1][0]["l_rec"].mean().item()  # type: ignore
+                mse = test_sae.compute_loss(x=activation_in, label=activation_out, during_init=True)[1][0]["l_rec"].mean().item()  # type: ignore
                 losses[norm] = mse
             best_norm = min(losses, key=losses.get)  # type: ignore
             return best_norm
@@ -687,7 +688,7 @@ def grid_search_best_init_norm(search_range: List[float]) -> float:
         )

         test_sae.set_decoder_norm_to_fixed_norm(
-            best_norm_fine_grained, force_exact=True
+            best_norm_fine_grained, force_exact=True, during_init=True
         )
         test_sae.encoder.weight.data = (
             test_sae.decoder.weight.data.T.clone().contiguous()
         )
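
For context on what the diff accomplishes: the commit threads a `during_init` flag from the decoder-norm grid search down into `compute_loss` (and on to `decoder_norm` and `set_decoder_norm_to_fixed_norm`), so candidate initialization norms can be scored by reconstruction loss before training starts. The sketch below is a minimal, self-contained illustration of that pattern under stated assumptions, not the lm_saes implementation: the `ToySAE` class, its layer sizes, and the simplified reconstruction loss are invented for the example.

import torch


class ToySAE:
    """Illustrative stand-in for the real SAE class; sizes and loss are arbitrary."""

    def __init__(self, d_model: int = 4, d_sae: int = 8):
        self.encoder = torch.nn.Linear(d_model, d_sae, bias=False)
        self.decoder = torch.nn.Linear(d_sae, d_model, bias=False)

    def decoder_norm(self, during_init: bool = False) -> torch.Tensor:
        # Per-feature L2 norm of the decoder columns. In the real code the flag
        # changes how the norm is obtained during initialization; the toy only
        # carries it through so the call sites mirror the diff.
        return torch.norm(self.decoder.weight, p=2, dim=0)

    def set_decoder_norm_to_fixed_norm(self, norm: float, during_init: bool = False) -> None:
        # Rescale every decoder column to the candidate norm.
        norms = self.decoder_norm(during_init=during_init).detach()
        self.decoder.weight.data *= norm / norms

    def compute_loss(self, x, label, during_init: bool = False) -> torch.Tensor:
        # Crude linear reconstruction; the flag reaches decoder_norm exactly as
        # the patched compute_loss forwards it.
        _ = self.decoder_norm(during_init=during_init)
        recon = self.encoder(x) @ self.decoder.weight.T
        return ((recon - label) ** 2).sum(dim=-1)  # per-sample reconstruction error


def grid_search_best_init_norm(search_range, activation_in, activation_out) -> float:
    # Mirror of the loop in the diff: try each candidate decoder norm with
    # during_init=True and keep the one with the smallest reconstruction loss.
    losses = {}
    for norm in search_range:
        test_sae = ToySAE()
        test_sae.set_decoder_norm_to_fixed_norm(norm, during_init=True)
        test_sae.encoder.weight.data = (
            test_sae.decoder.weight.data.T.clone().contiguous()
        )
        mse = test_sae.compute_loss(activation_in, activation_out, during_init=True).mean().item()
        losses[norm] = mse
    return min(losses, key=losses.get)

With toy activations such as `acts = torch.randn(16, 4)`, calling `grid_search_best_init_norm([0.5, 1.0, 2.0], acts, acts)` returns the candidate norm whose re-initialized SAE gives the lowest mean reconstruction error, mirroring how the real `grid_search_best_init_norm` selects `best_norm` in the diff.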
