diff --git a/src/lm_saes/activation/token_source.py b/src/lm_saes/activation/token_source.py
index a215c02..2cb7f5e 100644
--- a/src/lm_saes/activation/token_source.py
+++ b/src/lm_saes/activation/token_source.py
@@ -39,7 +39,7 @@ def fill_with_one_batch(self, batch, pack) -> None:
         if self.is_dataset_tokenized:
             tokens: torch.Tensor = batch["tokens"].to(self.device)
         else:
-            tokens = self.model.to_tokens(batch["text"], prepend_bos=not pack).to(self.device)
+            tokens = self.model.to_tokens(batch["text"], prepend_bos=False).to(self.device)
         if pack:
             while tokens.size(0) > 0:
                 cur_tokens = tokens[0]
@@ -59,7 +59,6 @@ def fill_with_one_batch(self, batch, pack) -> None:
         if tokens.size(1) < self.seq_len:
             pad_len = self.seq_len - tokens.size(1)
             tokens = torch.cat([tokens, torch.full((tokens.size(0), pad_len), self.model.tokenizer.pad_token_id, dtype=torch.long, device=self.device)], dim=1)
-        self.token_buffer = torch.cat([self.token_buffer, tokens], dim=0)
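
Note for reviewers: the behavioral change in the first hunk is that `to_tokens` no longer prepends a BOS token on the non-packing path (previously `prepend_bos=not pack` inserted one whenever `pack` was false), presumably making both paths consistent. A minimal sketch of the difference, assuming `self.model` is a TransformerLens `HookedTransformer` (the model name below is illustrative, not part of this PR):

```python
# Minimal sketch of the prepend_bos change; assumes TransformerLens.
# "gpt2" is an illustrative model choice, not the one used by this repo.
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")

# Old non-packing behavior (prepend_bos=not pack, with pack=False):
# a BOS token is inserted at position 0 of every sequence.
with_bos = model.to_tokens("Hello world", prepend_bos=True)

# New behavior (prepend_bos=False): raw tokens only, no BOS.
without_bos = model.to_tokens("Hello world", prepend_bos=False)

# The BOS token accounts for exactly one extra position per sequence.
assert with_bos.shape[-1] == without_bos.shape[-1] + 1
```

With packing, several texts are concatenated into fixed-length rows, so a tokenizer-injected BOS would land mid-row anyway; the packing path already disabled it before this change.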