From b2d23803eff9794ddcb4ab8c69fd1841f47f39c9 Mon Sep 17 00:00:00 2001
From: Hzfinfdu
Date: Tue, 2 Jul 2024 17:20:45 +0800
Subject: [PATCH] feature(activation): add an option to prepend bos before
 each sentence. Recommended to be True except for GPT2 SAEs

---
 src/lm_saes/activation/activation_source.py | 2 +-
 src/lm_saes/activation/token_source.py      | 9 ++++++---
 src/lm_saes/config.py                       | 5 +++++
 3 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/src/lm_saes/activation/activation_source.py b/src/lm_saes/activation/activation_source.py
index ce8a01c5..d0b5d40e 100644
--- a/src/lm_saes/activation/activation_source.py
+++ b/src/lm_saes/activation/activation_source.py
@@ -40,7 +40,7 @@ def __init__(self, model: HookedTransformer, cfg: ActivationStoreConfig):
         self.cfg = cfg
 
     def next(self) -> Dict[str, torch.Tensor] | None:
-        tokens = self.token_source.next(self.cfg.dataset.store_batch_size)
+        tokens = self.next_tokens(self.cfg.dataset.store_batch_size)
         if tokens is None:
             return None
 
diff --git a/src/lm_saes/activation/token_source.py b/src/lm_saes/activation/token_source.py
index 2cb7f5ec..f825e54e 100644
--- a/src/lm_saes/activation/token_source.py
+++ b/src/lm_saes/activation/token_source.py
@@ -17,6 +17,7 @@ def __init__(
         concat_tokens: list[bool],
         seq_len: int,
         sample_probs: list[float],
+        prepend_bos: list[bool],
     ):
         self.dataloader = dataloader
         self.model = model
@@ -33,13 +34,14 @@ def __init__(
         self.resid = torch.tensor([], dtype=torch.long, device=self.device)
 
         self.sample_probs = sample_probs
+        self.prepend_bos = prepend_bos
 
-    def fill_with_one_batch(self, batch, pack) -> None:
+    def fill_with_one_batch(self, batch, pack: bool, prepend_bos: bool) -> None:
         if self.is_dataset_tokenized:
             tokens: torch.Tensor = batch["tokens"].to(self.device)
         else:
-            tokens = self.model.to_tokens(batch["text"], prepend_bos=False).to(self.device)
+            tokens = self.model.to_tokens(batch["text"], prepend_bos=prepend_bos).to(self.device)
 
         if pack:
             while tokens.size(0) > 0:
                 cur_tokens = tokens[0]
@@ -81,7 +83,7 @@ def next(self, batch_size: int) -> torch.Tensor | None:
             else:
                 return None
 
-            self.fill_with_one_batch(batch, self.concat_tokens[dataset_idx_to_fetch])
+            self.fill_with_one_batch(batch, self.concat_tokens[dataset_idx_to_fetch], prepend_bos=self.prepend_bos[dataset_idx_to_fetch])
 
         ret = self.token_buffer[:batch_size]
         self.token_buffer = self.token_buffer[batch_size:]
@@ -120,4 +122,5 @@ def from_config(model: HookedTransformer, cfg: TextDatasetConfig):
         concat_tokens=cfg.concat_tokens,
         seq_len=cfg.context_size,
         sample_probs=cfg.sample_probs,
+        prepend_bos=cfg.prepend_bos,
     )
\ No newline at end of file
diff --git a/src/lm_saes/config.py b/src/lm_saes/config.py
index 438835aa..808719e9 100644
--- a/src/lm_saes/config.py
+++ b/src/lm_saes/config.py
@@ -106,6 +106,7 @@ class TextDatasetConfig(RunnerConfig):
     context_size: int = 128
     store_batch_size: int = 64
     sample_probs: List[float] = field(default_factory=lambda: [1.0])
+    prepend_bos: List[bool] = field(default_factory=lambda: [False])
 
     def __post_init__(self):
         super().__post_init__()
@@ -115,10 +116,14 @@ def __post_init__(self):
 
         if isinstance(self.concat_tokens, bool):
             self.concat_tokens = [self.concat_tokens]
 
+        if isinstance(self.prepend_bos, bool):
+            self.prepend_bos = [self.prepend_bos]
+
         self.sample_probs = [p / sum(self.sample_probs) for p in self.sample_probs]
         assert len(self.sample_probs) == len(self.dataset_path), "Number of sample_probs must match number of dataset paths"
         assert len(self.concat_tokens) == len(self.dataset_path), "Number of concat_tokens must match number of dataset paths"
+        assert len(self.prepend_bos) == len(self.dataset_path), "Number of prepend_bos must match number of dataset paths"
 
 
 @dataclass(kw_only=True)
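
Usage sketch: with this patch, `prepend_bos` is configured per dataset on `TextDatasetConfig`, mirroring `concat_tokens` and `sample_probs`; a bare bool is broadcast to all datasets in `__post_init__`. The dataset paths below are hypothetical, and the remaining config fields are assumed to have usable defaults:

    from lm_saes.config import TextDatasetConfig

    # Two datasets with per-dataset settings; each list must have one
    # entry per dataset_path, or the __post_init__ asserts will fire.
    cfg = TextDatasetConfig(
        dataset_path=["openwebtext", "pile-subset"],  # hypothetical paths
        concat_tokens=[True, False],
        sample_probs=[0.8, 0.2],
        prepend_bos=[True, True],  # recommended True, except for GPT2 SAEs
        context_size=128,
        store_batch_size=64,
    )

    # A scalar bool is broadcast to every dataset by __post_init__:
    cfg_gpt2 = TextDatasetConfig(
        dataset_path=["openwebtext"],  # hypothetical path
        concat_tokens=True,
        prepend_bos=False,  # GPT2 SAEs: keep the previous no-BOS behavior
    )

Note that pre-tokenized datasets take the `batch["tokens"]` branch and never pass through `to_tokens`, so `prepend_bos` only affects text datasets, where `model.to_tokens(batch["text"], prepend_bos=...)` decides whether a BOS token leads each sentence.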