Commit

…into 11-proposal-accelerate-inference-in-transformerlens
StarConnor committed Jun 29, 2024
2 parents b85c98e + 7382893 commit a12cc22
Showing 2 changed files with 18 additions and 7 deletions.
3 changes: 1 addition & 2 deletions src/lm_saes/activation/token_source.py
@@ -39,7 +39,7 @@ def fill_with_one_batch(self, batch, pack) -> None:
         if self.is_dataset_tokenized:
             tokens: torch.Tensor = batch["tokens"].to(self.device)
         else:
-            tokens = self.model.to_tokens(batch["text"], prepend_bos=not pack).to(self.device)
+            tokens = self.model.to_tokens(batch["text"], prepend_bos=False).to(self.device)
         if pack:
             while tokens.size(0) > 0:
                 cur_tokens = tokens[0]
@@ -59,7 +59,6 @@ def fill_with_one_batch(self, batch, pack) -> None:
             if tokens.size(1) < self.seq_len:
                 pad_len = self.seq_len - tokens.size(1)
                 tokens = torch.cat([tokens, torch.full((tokens.size(0), pad_len), self.model.tokenizer.pad_token_id, dtype=torch.long, device=self.device)], dim=1)
-
         self.token_buffer = torch.cat([self.token_buffer, tokens], dim=0)
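The change above stops prepending a BOS token for non-packed batches (with the old prepend_bos=not pack, the packed path already skipped it), so both paths now tokenize identically. A minimal sketch of what prepend_bos controls in TransformerLens, assuming a standard HookedTransformer; the model name and input string are illustrative:

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")

# With prepend_bos=True, to_tokens prepends the tokenizer's BOS id.
with_bos = model.to_tokens("Hello world", prepend_bos=True)
# With prepend_bos=False, as in this commit, only the text's own tokens remain.
without_bos = model.to_tokens("Hello world", prepend_bos=False)

assert with_bos.shape[1] == without_bos.shape[1] + 1
assert with_bos[0, 0].item() == model.tokenizer.bos_token_id
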
22 changes: 17 additions & 5 deletions src/lm_saes/evals.py
@@ -97,7 +97,7 @@ def recons_loss_batched(
     n_batches: int = 100,
 ):
     losses = []
-    if (not cfg.use_ddp or cfg.rank == 0):
+    if not cfg.use_ddp or cfg.rank == 0:
         pbar = tqdm(total=n_batches, desc="Evaluation", smoothing=0.01)
     for _ in range(n_batches):
         batch_tokens = activation_store.next_tokens(cfg.act_store.dataset.store_batch_size)
@@ -118,10 +118,10 @@
                 zero_abl_loss.mean().item(),
             )
         )
-        if (not cfg.use_ddp or cfg.rank == 0):
+        if not cfg.use_ddp or cfg.rank == 0:
             pbar.update(1)
 
-    if (not cfg.use_ddp or cfg.rank == 0):
+    if not cfg.use_ddp or cfg.rank == 0:
         pbar.close()
 
     losses = pd.DataFrame(
@@ -139,7 +139,8 @@ def get_recons_loss(
     batch_tokens: torch.Tensor,
 ):
     batch_tokens = batch_tokens.to(torch.int64)
-    loss = model.forward(batch_tokens, return_type="loss")
+
+    loss = model.forward(batch_tokens, return_type="loss", loss_per_token=True)
 
     _, cache = model.run_with_cache_until(
         batch_tokens,
@@ -157,12 +157,23 @@ def replacement_hook(activations: torch.Tensor, hook: Any):
         batch_tokens,
         return_type="loss",
         fwd_hooks=[(cfg.sae.hook_point_out, replacement_hook)],
+        loss_per_token=True
     )
 
     zero_abl_loss: torch.Tensor = model.run_with_hooks(
-        batch_tokens, return_type="loss", fwd_hooks=[(cfg.sae.hook_point_out, zero_ablate_hook)]
+        batch_tokens, return_type="loss", fwd_hooks=[(cfg.sae.hook_point_out, zero_ablate_hook)], loss_per_token=True
     )
 
+    logits_mask = torch.logical_and(batch_tokens.ne(model.tokenizer.eos_token_id), batch_tokens.ne(model.tokenizer.pad_token_id))
+    logits_mask = torch.logical_and(logits_mask, batch_tokens.ne(model.tokenizer.bos_token_id))
+    logits_mask = logits_mask[:, 1:]
+
+    def get_useful_token_loss(per_token_loss):
+        per_token_loss = per_token_loss.where(logits_mask, 0)
+        return per_token_loss.sum() / per_token_loss.ne(0).sum()
+
+    loss, recons_loss, zero_abl_loss = get_useful_token_loss(loss), get_useful_token_loss(recons_loss), get_useful_token_loss(zero_abl_loss)
+
     score = (zero_abl_loss - recons_loss) / (zero_abl_loss - loss)
 
     return score, loss, recons_loss, zero_abl_loss
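With loss_per_token=True, TransformerLens returns the unreduced next-token loss of shape [batch, seq_len - 1] instead of a scalar mean, which is why logits_mask is sliced with [:, 1:]: the loss at position i corresponds to predicting token i + 1. A self-contained sketch of the masked averaging above, using random tensors in place of real model losses; the shapes, token ids, and variable names are illustrative assumptions:

import torch

batch, seq = 2, 8
bos_id, eos_id, pad_id = 1, 2, 0  # illustrative special-token ids

batch_tokens = torch.randint(3, 100, (batch, seq))
batch_tokens[:, 0] = bos_id    # BOS at the start of each sequence
batch_tokens[0, -2:] = pad_id  # some right padding

# One loss per predicted position, as returned by loss_per_token=True.
per_token_loss = torch.rand(batch, seq - 1)

# Keep only positions whose target token is not BOS/EOS/pad, dropping the
# first column to align the mask with the next-token loss.
mask = batch_tokens.ne(eos_id) & batch_tokens.ne(pad_id) & batch_tokens.ne(bos_id)
mask = mask[:, 1:]

masked = per_token_loss.where(mask, torch.zeros_like(per_token_loss))
mean_loss = masked.sum() / mask.sum()

Dividing by mask.sum() here is equivalent to the commit's per_token_loss.ne(0).sum() unless a kept position happens to have a loss of exactly zero. The resulting score, (zero_abl_loss - recons_loss) / (zero_abl_loss - loss), then reads as the fraction of the clean-to-ablated loss gap recovered by the SAE reconstruction: 1 for a perfect reconstruction, 0 for one no better than zero-ablation.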
