diff --git a/src/lm_saes/evals.py b/src/lm_saes/evals.py
index 5ac765d..169a22c 100644
--- a/src/lm_saes/evals.py
+++ b/src/lm_saes/evals.py
@@ -97,7 +97,7 @@ def recons_loss_batched(
     n_batches: int = 100,
 ):
     losses = []
-    if (not cfg.use_ddp or cfg.rank == 0):
+    if not cfg.use_ddp or cfg.rank == 0:
         pbar = tqdm(total=n_batches, desc="Evaluation", smoothing=0.01)
     for _ in range(n_batches):
         batch_tokens = activation_store.next_tokens(cfg.act_store.dataset.store_batch_size)
@@ -118,10 +118,10 @@ def recons_loss_batched(
                 zero_abl_loss.mean().item(),
             )
         )
-        if (not cfg.use_ddp or cfg.rank == 0):
+        if not cfg.use_ddp or cfg.rank == 0:
             pbar.update(1)
 
-    if (not cfg.use_ddp or cfg.rank == 0):
+    if not cfg.use_ddp or cfg.rank == 0:
         pbar.close()
 
     losses = pd.DataFrame(
@@ -139,7 +139,8 @@ def get_recons_loss(
     batch_tokens: torch.Tensor,
 ):
     batch_tokens = batch_tokens.to(torch.int64)
-    loss = model.forward(batch_tokens, return_type="loss")
+
+    loss = model.forward(batch_tokens, return_type="loss", loss_per_token=True)
 
     _, cache = model.run_with_cache_until(
         batch_tokens,
@@ -157,12 +158,23 @@ def replacement_hook(activations: torch.Tensor, hook: Any):
         batch_tokens,
         return_type="loss",
         fwd_hooks=[(cfg.sae.hook_point_out, replacement_hook)],
+        loss_per_token=True
    )
 
     zero_abl_loss: torch.Tensor = model.run_with_hooks(
-        batch_tokens, return_type="loss", fwd_hooks=[(cfg.sae.hook_point_out, zero_ablate_hook)]
+        batch_tokens, return_type="loss", fwd_hooks=[(cfg.sae.hook_point_out, zero_ablate_hook)], loss_per_token=True
     )
 
+    logits_mask = torch.logical_and(batch_tokens.ne(model.tokenizer.eos_token_id), batch_tokens.ne(model.tokenizer.pad_token_id))
+    logits_mask = torch.logical_and(logits_mask, batch_tokens.ne(model.tokenizer.bos_token_id))
+    logits_mask = logits_mask[:, 1:]
+
+    def get_useful_token_loss(per_token_loss):
+        per_token_loss = per_token_loss.where(logits_mask, 0)
+        return per_token_loss.sum() / per_token_loss.ne(0).sum()
+
+    loss, recons_loss, zero_abl_loss = get_useful_token_loss(loss), get_useful_token_loss(recons_loss), get_useful_token_loss(zero_abl_loss)
+
     score = (zero_abl_loss - recons_loss) / (zero_abl_loss - loss)
 
     return score, loss, recons_loss, zero_abl_loss
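
Passing `loss_per_token=True` makes the three losses come back as per-position tensors rather than scalars, so positions whose target is a special token can be masked out before averaging. Below is a minimal sketch of the masking logic in isolation, assuming (as in TransformerLens) that the per-token loss has shape `[batch, seq_len - 1]` with one entry per predicted token; the `BOS`/`EOS`/`PAD` ids are hypothetical stand-ins for `model.tokenizer.bos_token_id`, `eos_token_id`, and `pad_token_id`:

```python
import torch

# Hypothetical special-token ids standing in for the tokenizer attributes.
BOS, EOS, PAD = 1, 2, 0

batch_tokens = torch.tensor(
    [[BOS, 11, 12, 13, EOS],
     [BOS, 21, 22, EOS, PAD]]
)

# loss_per_token=True yields one loss per *predicted* token: [batch, seq_len - 1].
per_token_loss = torch.rand(batch_tokens.shape[0], batch_tokens.shape[1] - 1)

# A position counts only if its target token (batch_tokens shifted left by one
# via [:, 1:]) is not a special token -- the same construction as in the diff.
logits_mask = batch_tokens.ne(EOS) & batch_tokens.ne(PAD) & batch_tokens.ne(BOS)
logits_mask = logits_mask[:, 1:]

def get_useful_token_loss(per_token_loss: torch.Tensor) -> torch.Tensor:
    # Zero out masked positions, then average over the remaining ones.
    # Dividing by ne(0).sum() mirrors the diff; note that an unmasked position
    # whose loss happens to be exactly 0 would also drop out of the denominator.
    per_token_loss = per_token_loss.where(logits_mask, torch.tensor(0.0))
    return per_token_loss.sum() / per_token_loss.ne(0).sum()

print(get_useful_token_loss(per_token_loss))  # scalar mean over non-special tokens
```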
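With the per-token averaging in place, `score` keeps its usual reading: the fraction of the cross-entropy gap between zero-ablation and the clean model that the SAE reconstruction recovers. With illustrative numbers, a clean loss of 2.0, a reconstruction loss of 2.4, and a zero-ablation loss of 6.0 give score = (6.0 - 2.4) / (6.0 - 2.0) = 0.9, i.e. 90% of the recoverable loss is restored.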