From 4c11b5d2102ff7fc60a23fbaa72d9954558204ff Mon Sep 17 00:00:00 2001 From: Hzfinfdu Date: Wed, 26 Jun 2024 15:43:48 +0800 Subject: [PATCH] fix(activation gen): remove bos. ce score is greatly improved --- src/lm_saes/activation/token_source.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/lm_saes/activation/token_source.py b/src/lm_saes/activation/token_source.py index a215c02a..2cb7f5ec 100644 --- a/src/lm_saes/activation/token_source.py +++ b/src/lm_saes/activation/token_source.py @@ -39,7 +39,7 @@ def fill_with_one_batch(self, batch, pack) -> None: if self.is_dataset_tokenized: tokens: torch.Tensor = batch["tokens"].to(self.device) else: - tokens = self.model.to_tokens(batch["text"], prepend_bos=not pack).to(self.device) + tokens = self.model.to_tokens(batch["text"], prepend_bos=False).to(self.device) if pack: while tokens.size(0) > 0: cur_tokens = tokens[0] @@ -59,7 +59,6 @@ def fill_with_one_batch(self, batch, pack) -> None: if tokens.size(1) < self.seq_len: pad_len = self.seq_len - tokens.size(1) tokens = torch.cat([tokens, torch.full((tokens.size(0), pad_len), self.model.tokenizer.pad_token_id, dtype=torch.long, device=self.device)], dim=1) - self.token_buffer = torch.cat([self.token_buffer, tokens], dim=0)