Commit d844b51

feature(activation): add an option to prepend bos before each sentence. Recommended to be True except for GPT2 SAEs.

Co-authored-by: Hzfinfdu <[email protected]>
Hzfinfdu authored Jul 2, 2024
1 parent 5ef3cba commit d844b51
Showing 3 changed files with 12 additions and 4 deletions.
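
In practice the new option is set per dataset on `TextDatasetConfig` (see the src/lm_saes/config.py diff below). A minimal usage sketch, with hypothetical dataset paths and any other required config fields omitted:

```python
from lm_saes.config import TextDatasetConfig

# Hypothetical example: two text datasets, prepending BOS to sentences of the
# first but not the second. Per the commit message, True is recommended for
# everything except GPT2 SAEs.
cfg = TextDatasetConfig(
    dataset_path=["my/dataset-a", "my/dataset-b"],  # hypothetical paths
    context_size=128,
    store_batch_size=64,
    sample_probs=[0.8, 0.2],
    concat_tokens=[True, True],
    prepend_bos=[True, False],
)
```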

src/lm_saes/activation/activation_source.py (2 changes: 1 addition & 1 deletion)

@@ -40,7 +40,7 @@ def __init__(self, model: HookedTransformer, cfg: ActivationStoreConfig):
         self.cfg = cfg
 
     def next(self) -> Dict[str, torch.Tensor] | None:
-        tokens = self.token_source.next(self.cfg.dataset.store_batch_size)
+        tokens = self.next_tokens(self.cfg.dataset.store_batch_size)
 
         if tokens is None:
             return None

src/lm_saes/activation/token_source.py (9 changes: 6 additions & 3 deletions)

@@ -17,6 +17,7 @@ def __init__(
         concat_tokens: list[bool],
         seq_len: int,
         sample_probs: list[float],
+        prepend_bos: list[bool]
     ):
         self.dataloader = dataloader
         self.model = model
@@ -33,13 +34,14 @@ def __init__(
         self.resid = torch.tensor([], dtype=torch.long, device=self.device)
 
         self.sample_probs = sample_probs
+        self.prepend_bos = prepend_bos
 
 
-    def fill_with_one_batch(self, batch, pack) -> None:
+    def fill_with_one_batch(self, batch, pack: bool, prepend_bos: bool) -> None:
         if self.is_dataset_tokenized:
             tokens: torch.Tensor = batch["tokens"].to(self.device)
         else:
-            tokens = self.model.to_tokens(batch["text"], prepend_bos=False).to(self.device)
+            tokens = self.model.to_tokens(batch["text"], prepend_bos=prepend_bos).to(self.device)
         if pack:
             while tokens.size(0) > 0:
                 cur_tokens = tokens[0]
@@ -81,7 +83,7 @@ def next(self, batch_size: int) -> torch.Tensor | None:
             else:
                 return None
 
-        self.fill_with_one_batch(batch, self.concat_tokens[dataset_idx_to_fetch])
+        self.fill_with_one_batch(batch, self.concat_tokens[dataset_idx_to_fetch], prepend_bos=self.prepend_bos[dataset_idx_to_fetch])
 
         ret = self.token_buffer[:batch_size]
         self.token_buffer = self.token_buffer[batch_size:]
@@ -120,4 +122,5 @@ def from_config(model: HookedTransformer, cfg: TextDatasetConfig):
         concat_tokens=cfg.concat_tokens,
         seq_len=cfg.context_size,
         sample_probs=cfg.sample_probs,
+        prepend_bos=cfg.prepend_bos
     )
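
For context (not part of the diff): `prepend_bos` is forwarded to TransformerLens's `HookedTransformer.to_tokens`, which controls whether the tokenizer's BOS token is prepended before encoding. A minimal sketch of the effect:

```python
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")

with_bos = model.to_tokens("Hello world.", prepend_bos=True)
without_bos = model.to_tokens("Hello world.", prepend_bos=False)

# prepend_bos=True adds exactly one leading token: the model's BOS
# (for GPT2 this is <|endoftext|>, which doubles as its EOS token).
assert with_bos.shape[1] == without_bos.shape[1] + 1
assert with_bos[0, 0].item() == model.tokenizer.bos_token_id
```

Note that the flag only affects untokenized text datasets; pre-tokenized batches (`batch["tokens"]`) are used as-is, as the diff above shows.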

src/lm_saes/config.py (5 changes: 5 additions & 0 deletions)

@@ -106,6 +106,7 @@ class TextDatasetConfig(RunnerConfig):
     context_size: int = 128
     store_batch_size: int = 64
     sample_probs: List[float] = field(default_factory=lambda: [1.0])
+    prepend_bos: List[bool] = field(default_factory=lambda: [False])
 
     def __post_init__(self):
         super().__post_init__()
@@ -115,10 +116,14 @@ def __post_init__(self):
         if isinstance(self.concat_tokens, bool):
             self.concat_tokens = [self.concat_tokens]
 
+        if isinstance(self.prepend_bos, bool):
+            self.prepend_bos = [self.prepend_bos]
+
         self.sample_probs = [p / sum(self.sample_probs) for p in self.sample_probs]
 
         assert len(self.sample_probs) == len(self.dataset_path), "Number of sample_probs must match number of dataset paths"
         assert len(self.concat_tokens) == len(self.dataset_path), "Number of concat_tokens must match number of dataset paths"
+        assert len(self.prepend_bos) == len(self.dataset_path), "Number of prepend_bos must match number of dataset paths"
 
 
 @dataclass(kw_only=True)
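
A quick illustration of the new `__post_init__` behavior (hypothetical paths; any other required fields omitted): a bare bool is wrapped into a one-element list, so multi-dataset configs must supply one flag per dataset or the new assert fires.

```python
from lm_saes.config import TextDatasetConfig

# A bare bool is normalized to a one-element list...
cfg = TextDatasetConfig(dataset_path=["my/dataset-a"], prepend_bos=True)
assert cfg.prepend_bos == [True]

# ...so with two datasets a single bool fails the new length check.
try:
    TextDatasetConfig(
        dataset_path=["my/dataset-a", "my/dataset-b"],
        sample_probs=[0.5, 0.5],
        concat_tokens=[True, True],
        prepend_bos=True,  # normalized to [True]; length 1 != 2 datasets
    )
except AssertionError as err:
    print(err)  # Number of prepend_bos must match number of dataset paths
```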
