Skip to content

Commit

Permalink
Merge pull request #50 from OpenMOSS/ft4supp
Browse files (browse the repository at this point in the history)
fix(sae): support grid searching for best init
  • Loading branch information
Hzfinfdu authored Aug 19, 2024
2 parents 654970b + ef0a24a commit b7cb227
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/lm_saes/sae.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -332,6 +332,7 @@ def compute_loss(
| None
) = None,
return_aux_data: bool = True,
+ during_init: bool = False,
) -> Union[
Float[torch.Tensor, "batch"],
tuple[
Expand Down Expand Up @@ -373,7 +374,7 @@ def compute_loss(
if self.cfg.sparsity_include_decoder_norm:

l_l1 = torch.norm(
- feature_acts_normed * self.decoder_norm(),
+ feature_acts_normed * self.decoder_norm(during_init=during_init),
p=self.cfg.lp,
dim=-1,
)
Expand Down Expand Up @@ -669,7 +670,7 @@ def grid_search_best_init_norm(search_range: List[float]) -> float:
test_sae.encoder.weight.data = (
test_sae.decoder.weight.data.T.clone().contiguous()
)
- mse = test_sae.compute_loss(x=activation_in, label=activation_out)[1][0]["l_rec"].mean().item() # type: ignore
+ mse = test_sae.compute_loss(x=activation_in, label=activation_out, during_init=True)[1][0]["l_rec"].mean().item() # type: ignore
losses[norm] = mse
best_norm = min(losses, key=losses.get) # type: ignore
return best_norm
Expand All @@ -687,7 +688,7 @@ def grid_search_best_init_norm(search_range: List[float]) -> float:
)

test_sae.set_decoder_norm_to_fixed_norm(
- best_norm_fine_grained, force_exact=True
+ best_norm_fine_grained, force_exact=True, during_init=True
)
test_sae.encoder.weight.data = (
test_sae.decoder.weight.data.T.clone().contiguous()
Expand Down

0 comments on commit b7cb227

Please sign in to comment.