Commit

style: reformat some files
dest1n1s committed Dec 19, 2024
1 parent 8f2757d commit d4f1da9
Showing 2 changed files with 19 additions and 9 deletions.
27 changes: 19 additions & 8 deletions src/lm_saes/activation/activation_source.py
@@ -120,33 +120,44 @@ def __init__(self, cfg: ActivationStoreConfig):
 
     def load_chunk_into_buffer(self, dataset_id, chunk_path: list[str], ban_token_list=None):
         if dataset_id not in self.chunk_buffer:
-            self.chunk_buffer[dataset_id] = torch.empty((0, self.cfg.lm.d_model), dtype=self.cfg.dtype, device=self.cfg.device)
+            self.chunk_buffer[dataset_id] = torch.empty(
+                (0, self.cfg.lm.d_model), dtype=self.cfg.dtype, device=self.cfg.device
+            )
         to_fill_length = self.cfg.n_tokens_in_buffer // len(self.chunk_paths) - self.chunk_buffer[dataset_id].size(0)
         while to_fill_length > 0 and len(chunk_path) > 0:
             chunk = load_activation_chunk(chunk_path.pop(), self.cfg.device)
             with_context = len(chunk["activation"].size()) == 3
             activation = chunk["activation"]
             if with_context:
                 chunk["context"] = chunk["context"].to(dtype=torch.long, device=self.cfg.device)
-                not_ban_token = torch.isin(rearrange(chunk["context"], "b l -> (b l)"), torch.tensor(ban_token_list, device=self.cfg.device), invert=True)
-                activation = rearrange(chunk["activation"], "b l d -> (b l) d")[not_ban_token]
+                not_ban_token = torch.isin(
+                    rearrange(chunk["context"], "b l -> (b l)"),
+                    torch.tensor(ban_token_list, device=self.cfg.device),
+                    invert=True,
+                )
+                activation = rearrange(chunk["activation"], "b l d -> (b l) d")[not_ban_token]
             self.chunk_buffer[dataset_id] = torch.cat([self.chunk_buffer[dataset_id], activation], dim=0)
             to_fill_length -= activation.size(0)
         return chunk_path
 
-
-    def next(self)-> Dict[str, torch.Tensor] | None:
+    def next(self) -> Dict[str, torch.Tensor] | None:
         for i, chunk_paths in enumerate(self.chunk_paths):
             self.sample_probs[i] = 0 if len(chunk_paths) == 0 else self.sample_probs[i]
         self.sample_probs = [p / sum(self.sample_probs) for p in self.sample_probs]
 
         for i, chunk_paths in enumerate(self.chunk_paths):
             self.chunk_paths[i] = self.load_chunk_into_buffer(i, chunk_paths, self.cfg.ban_token_list[i])
 
-        next_length_list = [min(int(self.sample_probs[i] * self.cfg.n_tokens_in_buffer), self.chunk_buffer[i].size(0)) for i in range(len(self.chunk_paths))]
+        next_length_list = [
+            min(int(self.sample_probs[i] * self.cfg.n_tokens_in_buffer), self.chunk_buffer[i].size(0))
+            for i in range(len(self.chunk_paths))
+        ]
 
-        ret = {self.hook_point: torch.cat([self.chunk_buffer[i][:next_length_list[i]] for i in range(len(self.chunk_paths))], dim=0)}
+        ret = {
+            self.hook_point: torch.cat(
+                [self.chunk_buffer[i][: next_length_list[i]] for i in range(len(self.chunk_paths))], dim=0
+            )
+        }
         return ret
 
     # def next(self) -> Dict[str, torch.Tensor] | None:
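For readers skimming the diff, a minimal, self-contained sketch of the ban-token filter that load_chunk_into_buffer applies may help. The tensor shapes and the ban list below are illustrative values, not taken from the repository:

    import torch
    from einops import rearrange

    batch, seq_len, d_model = 2, 4, 8
    chunk = {
        "activation": torch.randn(batch, seq_len, d_model),
        "context": torch.randint(0, 100, (batch, seq_len)),  # token ids per position
    }
    ban_token_list = [0, 1, 2]  # hypothetical special-token ids to drop

    # Flatten (batch, seq) into a single token axis, then keep positions whose
    # token id is NOT in the ban list (invert=True flips torch.isin).
    flat_tokens = rearrange(chunk["context"], "b l -> (b l)")
    not_ban_token = torch.isin(flat_tokens, torch.tensor(ban_token_list), invert=True)
    activation = rearrange(chunk["activation"], "b l d -> (b l) d")[not_ban_token]

    print(activation.shape)  # (n_kept_tokens, d_model), with n_kept_tokens <= batch * seq_len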
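Likewise, a standalone sketch of the sampling step in next(): zero out the probability of any exhausted dataset, renormalize, then take a proportional slice from each per-dataset buffer. The buffer sizes and probabilities here are made up for illustration:

    import torch

    n_tokens_in_buffer = 100
    sample_probs = [0.5, 0.3, 0.2]
    chunk_buffer = [torch.randn(60, 8), torch.randn(0, 8), torch.randn(40, 8)]
    chunk_paths = [["a.pt"], [], ["c.pt"]]  # dataset 1 has no chunks left

    # Drop exhausted datasets from the distribution and renormalize, mirroring the diff.
    sample_probs = [0 if len(paths) == 0 else p for paths, p in zip(chunk_paths, sample_probs)]
    sample_probs = [p / sum(sample_probs) for p in sample_probs]

    # Take at most the proportional share from each buffer, capped by its size.
    next_length_list = [
        min(int(sample_probs[i] * n_tokens_in_buffer), chunk_buffer[i].size(0))
        for i in range(len(chunk_paths))
    ]
    batch = torch.cat([chunk_buffer[i][: next_length_list[i]] for i in range(len(chunk_paths))], dim=0)
    print(next_length_list, batch.shape)  # e.g. [60, 0, 28] torch.Size([88, 8])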
1 change: 0 additions & 1 deletion src/lm_saes/config.py
@@ -145,7 +145,6 @@ class ActivationStoreConfig(BaseModelConfig, RunnerConfig):
     hook_points: List[str] = field(default_factory=lambda: ["blocks.0.hook_resid_pre"])
     """ Hook points to store activations from, i.e. the layer output of which is used for training/evaluating the dictionary. Will run until the last hook point in the list, so make sure to order them correctly. """
 
-
     use_cached_activations: bool = False
     cached_activations_path: List[str] = None  # type: ignore
     shuffle_activations: bool = True
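The docstring above is the only guidance on hook-point ordering, so a hedged usage sketch may be useful. It assumes ActivationStoreConfig can be instantiated directly with these keyword arguments (any other inherited fields keeping their defaults); the layer indices are hypothetical:

    from lm_saes.config import ActivationStoreConfig

    # Order hook points from shallow to deep: the forward pass runs only
    # until the last hook point in the list.
    cfg = ActivationStoreConfig(
        hook_points=["blocks.3.hook_resid_pre", "blocks.8.hook_resid_pre"],
        use_cached_activations=False,
        shuffle_activations=True,
    )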
