fix: generate activations
dest1n1s committed Dec 3, 2024
1 parent 3f445b1 commit 2320def
Showing 6 changed files with 86 additions and 97 deletions.
2 changes: 1 addition & 1 deletion TransformerLens (submodule pointer update; no file diff shown)
5 changes: 2 additions & 3 deletions src/lm_saes/activation/activation_dataset.py
@@ -5,9 +5,7 @@
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer

-from ..config import (
-    ActivationGenerationConfig,
-)
+from ..config import ActivationGenerationConfig
from ..utils.misc import is_master, print_once
from .activation_store import ActivationStore
from .token_source import TokenSource
@@ -90,6 +88,7 @@ def make_activation_dataset(model: HookedTransformer, cfg: ActivationGenerationC
pbar = tqdm(
total=total_generating_tokens,
desc=f"Activation dataset Rank {dist.get_rank()}" if dist.is_initialized() else "Activation dataset",
+        smoothing=0.001,
)

while n_tokens < total_generating_tokens:
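
The added smoothing=0.001 makes tqdm report a near-global average throughput instead of the instantaneous rate, so the ETA stays stable across uneven batch times during long activation-generation runs. A minimal standalone sketch of the effect (not part of this commit; the totals and sleep are placeholders):

import time

from tqdm import tqdm

# smoothing close to 0 averages the throughput over the whole run,
# so the displayed rate/ETA does not jump when batch times vary.
pbar = tqdm(total=1_000_000, desc="Activation dataset", smoothing=0.001)
for _ in range(100):
    time.sleep(0.01)      # stand-in for a model forward pass
    pbar.update(10_000)   # tokens produced by this batch
pbar.close()
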
19 changes: 18 additions & 1 deletion src/lm_saes/activation/activation_store.py
@@ -6,6 +6,7 @@
import torch.distributed._functional_collectives as funcol
import torch.utils.data
from torch.distributed.device_mesh import init_device_mesh
+from tqdm import tqdm
from transformer_lens import HookedTransformer

from ..config import ActivationStoreConfig
@@ -26,7 +27,10 @@ def __init__(self, act_source: ActivationSource, cfg: ActivationStoreConfig):
self.tp_size = cfg.tp_size
self._store: Dict[str, torch.Tensor] = {}
self._all_gather_buffer: Dict[str, torch.Tensor] = {}
-        self.device_mesh = init_device_mesh("cuda", (self.ddp_size, self.tp_size), mesh_dim_names=("ddp", "tp"))
+        if self.tp_size > 1 or self.ddp_size > 1:
+            self.device_mesh = init_device_mesh("cuda", (self.ddp_size, self.tp_size), mesh_dim_names=("ddp", "tp"))
+        else:
+            self.device_mesh = None

def initialize(self):
self.refill()
@@ -41,6 +45,14 @@ def shuffle(self):
self._store[k] = self._store[k][perm]

def refill(self):
+        pbar = tqdm(
+            total=self.buffer_size,
+            desc="Refilling activation store",
+            smoothing=0,
+            leave=False,
+            initial=self.__len__(),
+        )
+        n_seqs = 0
while self.__len__() < self.buffer_size:
new_act = self.act_source.next()
if new_act is None:
@@ -53,6 +65,10 @@ def refill(self):
self._store[k] = torch.cat([self._store[k], v], dim=0)
# Check if all activations have the same size
assert len(set(v.size(0) for v in self._store.values())) == 1
+            n_seqs += 1
+            pbar.update(next(iter(new_act.values())).size(0))
+            pbar.set_postfix({"Sequences": n_seqs})
+        pbar.close()

def __len__(self):
if len(self._store) == 0:
@@ -75,6 +91,7 @@ def next(self, batch_size) -> Dict[str, torch.Tensor] | None:
if dist.is_initialized(): # Wait for all processes to refill the store
dist.barrier()
if self.tp_size > 1:
+            assert self.device_mesh is not None, "Device mesh not initialized"
for k, v in self._store.items():
if k not in self._all_gather_buffer:
self._all_gather_buffer[k] = torch.empty(size=(0,), dtype=v.dtype, device=self.device)
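
The main behavioral change in this file is that the device mesh is only built when data or tensor parallelism is actually requested: init_device_mesh generally expects a launched torch.distributed environment, which a plain single-process run does not provide, so self.device_mesh is left as None and asserted only where tensor parallelism needs it. The refill loop also gains a progress bar that tracks activations and sequences as the store fills. A minimal sketch of the guard, with ddp_size/tp_size as stand-ins for the config fields:

from torch.distributed.device_mesh import init_device_mesh

ddp_size, tp_size = 1, 1  # placeholders for cfg.ddp_size / cfg.tp_size

# Only build the 2-D ("ddp", "tp") mesh when some parallelism is requested;
# a single-process run skips it, and downstream code must not assume a mesh.
if ddp_size > 1 or tp_size > 1:
    device_mesh = init_device_mesh("cuda", (ddp_size, tp_size), mesh_dim_names=("ddp", "tp"))
else:
    device_mesh = None

assert device_mesh is None  # with both sizes at 1, no mesh is created
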
14 changes: 13 additions & 1 deletion src/lm_saes/activation/token_source.py
@@ -43,7 +43,18 @@ def __init__(

def fill_with_one_batch(self, batch: dict[str, Any], pack: bool, prepend_bos: bool) -> None:
if self.is_dataset_tokenized:
-            tokens: torch.Tensor = batch["tokens"].to(self.device)
+            if isinstance(batch["input_ids"], torch.Tensor):
+                assert not batch["input_ids"].dtype.is_floating_point, "input_ids must be a tensor of integers"
+                tokens = batch["input_ids"].to(self.device)
+            else:
+                assert isinstance(batch["input_ids"], list), "input_ids must be a list or a tensor"
+                print("Batch size:", len(batch["input_ids"]), "Type:", type(batch["input_ids"]))
+                print("Sequence length:", len(batch["input_ids"][0]), "Type:", type(batch["input_ids"][0]))
+                # Check if all sequences in the batch have the same length
+                assert all(
+                    len(seq) == len(batch["input_ids"][0]) for seq in batch["input_ids"]
+                ), "All sequences must have the same length"
+                tokens = torch.tensor(batch["input_ids"], dtype=torch.long, device=self.device)
else:
tokens = self.model.to_tokens(batch["text"], prepend_bos=prepend_bos).to(self.device)
if pack:
@@ -124,6 +135,7 @@ def _process_dataset(dataset_path: str, cfg: TextDatasetConfig):
shard = dataset.shard(num_shards=dist.get_world_size(), index=shard_id, contiguous=True)
else:
shard = dataset
+    shard = shard.with_format("torch")

dataloader = DataLoader(
dataset=cast(Dataset[dict[str, Any]], shard), batch_size=cfg.store_batch_size, pin_memory=True
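
For pre-tokenized datasets the batch key is now input_ids rather than tokens, and both layouts are accepted: an integer tensor (which the added .with_format("torch") in _process_dataset produces) or a plain list of equal-length sequences. A self-contained sketch of that dispatch on made-up data (to_token_tensor is a hypothetical helper, not a function in the repository):

import torch


def to_token_tensor(batch: dict, device: str = "cpu") -> torch.Tensor:
    ids = batch["input_ids"]
    if isinstance(ids, torch.Tensor):
        # Already tokenized and collated into a tensor; just check the dtype.
        assert not ids.dtype.is_floating_point, "input_ids must be integers"
        return ids.to(device)
    # Otherwise expect a list of sequences that all share one length.
    assert isinstance(ids, list), "input_ids must be a list or a tensor"
    assert all(len(seq) == len(ids[0]) for seq in ids), "all sequences must have the same length"
    return torch.tensor(ids, dtype=torch.long, device=device)


print(to_token_tensor({"input_ids": [[1, 2, 3], [4, 5, 6]]}).shape)  # torch.Size([2, 3])
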
7 changes: 7 additions & 0 deletions src/lm_saes/entrypoint.py
@@ -9,6 +9,7 @@ class SupportedRunner(Enum):
EVAL = "eval"
ANALYZE = "analyze"
PRUNE = "prune"
+    GENERATE_ACTIVATIONS = "gen-activations"

def __str__(self):
return self.value
@@ -97,6 +98,12 @@ def entrypoint():

config = LanguageModelSAEPruningConfig.from_flattened(config)
language_model_sae_prune_runner(config)
+    elif args.runner == SupportedRunner.GENERATE_ACTIVATIONS:
+        from lm_saes.config import ActivationGenerationConfig
+        from lm_saes.runner import activation_generation_runner
+
+        config = ActivationGenerationConfig.from_flattened(config)
+        activation_generation_runner(config)
else:
raise ValueError(f"Unsupported runner: {args.runner}.")
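
The entrypoint gains a gen-activations runner that parses the flattened config into ActivationGenerationConfig and hands it to activation_generation_runner. A small sketch of how the enum value maps a command-line string onto the new runner (members above EVAL are elided here; the surrounding argument parsing is unchanged):

from enum import Enum


class SupportedRunner(Enum):
    EVAL = "eval"
    ANALYZE = "analyze"
    PRUNE = "prune"
    GENERATE_ACTIVATIONS = "gen-activations"

    def __str__(self):
        return self.value


# argparse can convert the raw CLI string straight into an enum member.
print(SupportedRunner("gen-activations") is SupportedRunner.GENERATE_ACTIVATIONS)  # True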

136 changes: 45 additions & 91 deletions src/lm_saes/runner.py
@@ -11,7 +11,12 @@
parallelize_module,
)
from transformer_lens import HookedTransformer
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    ChameleonForConditionalGeneration,
+    PreTrainedModel,
+)

from .activation.activation_dataset import make_activation_dataset
from .activation.activation_source import CachedActivationSource
@@ -21,6 +26,7 @@
from .config import (
ActivationGenerationConfig,
FeaturesDecoderConfig,
+    LanguageModelConfig,
LanguageModelCrossCoderTrainingConfig,
LanguageModelSAEAnalysisConfig,
LanguageModelSAEPruningConfig,
@@ -36,36 +42,47 @@
from .utils.misc import is_master


+def get_model(cfg: LanguageModelConfig):
+    if "chameleon" in cfg.model_name:
+        hf_model = ChameleonForConditionalGeneration.from_pretrained(
+            (cfg.model_name if cfg.model_from_pretrained_path is None else cfg.model_from_pretrained_path),
+            cache_dir=cfg.cache_dir,
+            local_files_only=cfg.local_files_only,
+            torch_dtype=cfg.dtype,
+        )
+    else:
+        hf_model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
+            (cfg.model_name if cfg.model_from_pretrained_path is None else cfg.model_from_pretrained_path),
+            cache_dir=cfg.cache_dir,
+            local_files_only=cfg.local_files_only,
+            torch_dtype=cfg.dtype,
+        )
+    hf_tokenizer = AutoTokenizer.from_pretrained(
+        (cfg.model_name if cfg.model_from_pretrained_path is None else cfg.model_from_pretrained_path),
+        trust_remote_code=True,
+        use_fast=True,
+        add_bos_token=True,
+    )
+    model = HookedTransformer.from_pretrained_no_processing(
+        cfg.model_name,
+        use_flash_attn=cfg.use_flash_attn,
+        device=cfg.device,
+        cache_dir=cfg.cache_dir,
+        hf_model=hf_model,
+        tokenizer=hf_tokenizer,
+        dtype=cfg.dtype,
+    )
+    model.eval()
+    return model


def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
if cfg.act_store.use_cached_activations:
activation_source = CachedActivationSource(cfg.act_store)
activation_store = ActivationStore(act_source=activation_source, cfg=cfg.act_store)
model = None
else:
-        hf_model = AutoModelForCausalLM.from_pretrained(
-            (cfg.lm.model_name if cfg.lm.model_from_pretrained_path is None else cfg.lm.model_from_pretrained_path),
-            cache_dir=cfg.lm.cache_dir,
-            local_files_only=cfg.lm.local_files_only,
-            torch_dtype=cfg.lm.dtype,
-        )
-        hf_tokenizer = AutoTokenizer.from_pretrained(
-            (cfg.lm.model_name if cfg.lm.model_from_pretrained_path is None else cfg.lm.model_from_pretrained_path),
-            trust_remote_code=True,
-            use_fast=True,
-            add_bos_token=True,
-        )
-
-        model = HookedTransformer.from_pretrained_no_processing(
-            cfg.lm.model_name,
-            use_flash_attn=cfg.lm.use_flash_attn,
-            device=cfg.lm.device,
-            cache_dir=cfg.lm.cache_dir,
-            hf_model=hf_model,
-            tokenizer=hf_tokenizer,
-            dtype=cfg.lm.dtype,
-        )
-        model.offload_params_after(cfg.act_store.hook_points[-1], torch.tensor([[0]], device=cfg.lm.device))
-        model.eval()
+        model = get_model(cfg.lm)
activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store)

if not cfg.finetuning and (
@@ -182,28 +199,7 @@ def language_model_sae_prune_runner(cfg: LanguageModelSAEPruningConfig):
cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_path))
cfg.lm.save_lm_config(os.path.join(cfg.exp_result_path))
sae = SparseAutoEncoder.from_config(cfg=cfg.sae)
-    hf_model = AutoModelForCausalLM.from_pretrained(
-        (cfg.lm.model_name if cfg.lm.model_from_pretrained_path is None else cfg.lm.model_from_pretrained_path),
-        cache_dir=cfg.lm.cache_dir,
-        local_files_only=cfg.lm.local_files_only,
-        torch_dtype=cfg.lm.dtype,
-    )
-    hf_tokenizer = AutoTokenizer.from_pretrained(
-        (cfg.lm.model_name if cfg.lm.model_from_pretrained_path is None else cfg.lm.model_from_pretrained_path),
-        trust_remote_code=True,
-        use_fast=True,
-        add_bos_token=True,
-    )
-    model = HookedTransformer.from_pretrained_no_processing(
-        cfg.lm.model_name,
-        use_flash_attn=cfg.lm.use_flash_attn,
-        device=cfg.lm.device,
-        cache_dir=cfg.lm.cache_dir,
-        hf_model=hf_model,
-        tokenizer=hf_tokenizer,
-        dtype=cfg.lm.dtype,
-    )
-    model.eval()
+    model = get_model(cfg.lm)
activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store)
if cfg.wandb.log_to_wandb and is_master():
wandb_config: dict = {
@@ -243,29 +239,7 @@ def language_model_sae_prune_runner(cfg: LanguageModelSAEPruningConfig):

def language_model_sae_eval_runner(cfg: LanguageModelSAERunnerConfig):
sae = SparseAutoEncoder.from_config(cfg=cfg.sae)
-    hf_model = AutoModelForCausalLM.from_pretrained(
-        (cfg.lm.model_name if cfg.lm.model_from_pretrained_path is None else cfg.lm.model_from_pretrained_path),
-        cache_dir=cfg.lm.cache_dir,
-        local_files_only=cfg.lm.local_files_only,
-    )
-
-    hf_tokenizer = AutoTokenizer.from_pretrained(
-        (cfg.lm.model_name if cfg.lm.model_from_pretrained_path is None else cfg.lm.model_from_pretrained_path),
-        trust_remote_code=True,
-        use_fast=True,
-        add_bos_token=True,
-    )
-    model = HookedTransformer.from_pretrained_no_processing(
-        cfg.lm.model_name,
-        use_flash_attn=cfg.lm.use_flash_attn,
-        device=cfg.lm.device,
-        cache_dir=cfg.lm.cache_dir,
-        hf_model=hf_model,
-        tokenizer=hf_tokenizer,
-        dtype=cfg.lm.dtype,
-    )
-
-    model.eval()
+    model = get_model(cfg.lm)
activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store)

if cfg.wandb.log_to_wandb and is_master():
@@ -301,27 +275,7 @@ def language_model_sae_eval_runner(cfg: LanguageModelSAERunnerConfig):


def activation_generation_runner(cfg: ActivationGenerationConfig):
-    hf_model = AutoModelForCausalLM.from_pretrained(
-        (cfg.lm.model_name if cfg.lm.model_from_pretrained_path is None else cfg.lm.model_from_pretrained_path),
-        cache_dir=cfg.lm.cache_dir,
-        local_files_only=cfg.lm.local_files_only,
-    )
-    hf_tokenizer = AutoTokenizer.from_pretrained(
-        (cfg.lm.model_name if cfg.lm.model_from_pretrained_path is None else cfg.lm.model_from_pretrained_path),
-        trust_remote_code=True,
-        use_fast=True,
-        add_bos_token=True,
-    )
-    model = HookedTransformer.from_pretrained_no_processing(
-        cfg.lm.model_name,
-        use_flash_attn=cfg.lm.use_flash_attn,
-        device=cfg.lm.device,
-        cache_dir=cfg.lm.cache_dir,
-        hf_model=hf_model,
-        tokenizer=hf_tokenizer,
-        dtype=cfg.lm.dtype,
-    )
-    model.eval()
+    model = get_model(cfg.lm)

make_activation_dataset(model, cfg)
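
All four runners previously repeated the same HF-model + tokenizer + HookedTransformer loading block; get_model now centralizes it and adds a branch for Chameleon checkpoints, which are loaded through ChameleonForConditionalGeneration instead of AutoModelForCausalLM. A standalone sketch of just that dispatch, using placeholder model names and only the Hugging Face side (load_hf_model is hypothetical; the real get_model additionally wraps the result with HookedTransformer.from_pretrained_no_processing and calls model.eval()):

import torch
from transformers import (
    AutoModelForCausalLM,
    ChameleonForConditionalGeneration,
    PreTrainedModel,
)


def load_hf_model(model_name: str, dtype: torch.dtype = torch.float32) -> PreTrainedModel:
    # Mirror the commit's dispatch: Chameleon checkpoints get their dedicated
    # class, everything else goes through the generic causal-LM auto class.
    if "chameleon" in model_name:
        return ChameleonForConditionalGeneration.from_pretrained(model_name, torch_dtype=dtype)
    return AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype)


model = load_hf_model("gpt2")  # a name containing "chameleon" would take the first branch
model.eval()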
