From ccac63a135f9c816cd5445da850c0a47e9eeabb8 Mon Sep 17 00:00:00 2001 From: Frankstein <20307140057@fudan.edu.cn> Date: Sun, 14 Jul 2024 18:28:32 +0800 Subject: [PATCH] feat: Implement tensor parallelism in SAE using device mesh - Introduced tensor parallelism which may contain bugs; further testing required. Existing data parallelism remains functional. - Save and load functionalities for SAE under tensor parallelism are yet to be implemented. - Additional tests needed for SAE and activation store under tensor parallelism. --- src/lm_saes/activation/activation_dataset.py | 13 +- src/lm_saes/activation/activation_source.py | 4 +- src/lm_saes/activation/activation_store.py | 102 ++++-- src/lm_saes/activation/token_source.py | 61 ++-- src/lm_saes/config.py | 108 +++--- src/lm_saes/entrypoint.py | 9 +- src/lm_saes/evals.py | 35 +- src/lm_saes/runner.py | 50 ++- src/lm_saes/sae.py | 336 ++++++++++++------- src/lm_saes/sae_training.py | 123 +++++-- src/lm_saes/utils/misc.py | 7 +- 11 files changed, 558 insertions(+), 290 deletions(-) diff --git a/src/lm_saes/activation/activation_dataset.py b/src/lm_saes/activation/activation_dataset.py index 9caba45..b596265 100644 --- a/src/lm_saes/activation/activation_dataset.py +++ b/src/lm_saes/activation/activation_dataset.py @@ -5,10 +5,9 @@ import os import torch.distributed as dist -from lm_saes.utils.misc import print_once from lm_saes.config import ActivationGenerationConfig from lm_saes.activation.token_source import TokenSource - +from lm_saes.utils.misc import is_master, print_once @torch.no_grad() def make_activation_dataset( @@ -22,19 +21,19 @@ def make_activation_dataset( token_source = TokenSource.from_config(model=model, cfg=cfg.dataset) - if not cfg.use_ddp or cfg.rank == 0: + if is_master(): for hook_point in cfg.hook_points: os.makedirs(os.path.join(cfg.activation_save_path, hook_point), exist_ok=False) - if cfg.use_ddp: + if cfg.ddp_size > 1: dist.barrier() - total_generating_tokens = cfg.total_generating_tokens // cfg.world_size + total_generating_tokens = cfg.total_generating_tokens // dist.get_world_size() else: total_generating_tokens = cfg.total_generating_tokens n_tokens = 0 chunk_idx = 0 - pbar = tqdm(total=total_generating_tokens, desc=f"Activation dataset Rank {cfg.rank}" if cfg.use_ddp else "Activation dataset") + pbar = tqdm(total=total_generating_tokens, desc=f"Activation dataset Rank {dist.get_rank()}" if dist.is_initialized() else "Activation dataset") while n_tokens < total_generating_tokens: act_dict = {hook_point: torch.empty((0, cfg.dataset.context_size, cfg.lm.d_model), dtype=cfg.lm.dtype, device=cfg.lm.device) for hook_point in cfg.hook_points} @@ -63,7 +62,7 @@ def make_activation_dataset( "context": context, "position": position, }, - os.path.join(cfg.activation_save_path, hook_point, f"chunk-{str(chunk_idx).zfill(5)}.pt" if not cfg.use_ddp else f"shard-{cfg.rank}-chunk-{str(chunk_idx).zfill(5)}.pt") + os.path.join(cfg.activation_save_path, hook_point, f"chunk-{str(chunk_idx).zfill(5)}.pt" if not dist.is_initialized() else f"shard-{dist.get_rank()}-chunk-{str(chunk_idx).zfill(5)}.pt") ) chunk_idx += 1 diff --git a/src/lm_saes/activation/activation_source.py b/src/lm_saes/activation/activation_source.py index d0b5d40..87e12d7 100644 --- a/src/lm_saes/activation/activation_source.py +++ b/src/lm_saes/activation/activation_source.py @@ -66,8 +66,8 @@ def __init__(self, cfg: ActivationStoreConfig): assert len(cfg.hook_points) == 1, "CachedActivationSource only supports one hook point" self.hook_point = 
cfg.hook_points[0] self.chunk_paths = list_activation_chunks(cfg.cached_activations_path[0], self.hook_point) - if cfg.use_ddp: - self.chunk_paths = [p for i, p in enumerate(self.chunk_paths) if i % cfg.world_size == cfg.rank] + if cfg.ddp_size > 1: + self.chunk_paths = [p for i, p in enumerate(self.chunk_paths) if i % dist.get_world_size() == dist.get_rank()] random.shuffle(self.chunk_paths) self.token_buffer = torch.empty((0, cfg.dataset.context_size), dtype=torch.long, device=cfg.device) diff --git a/src/lm_saes/activation/activation_store.py b/src/lm_saes/activation/activation_store.py index ed2c49e..49317a2 100644 --- a/src/lm_saes/activation/activation_store.py +++ b/src/lm_saes/activation/activation_store.py @@ -1,29 +1,37 @@ from typing import Callable, Dict, Generator, Iterable import torch +from torch.cuda import is_initialized import torch.distributed as dist import torch.utils.data from transformer_lens import HookedTransformer from lm_saes.config import ActivationStoreConfig -from lm_saes.activation.activation_source import ActivationSource, CachedActivationSource, TokenActivationSource +from lm_saes.activation.activation_source import ( + ActivationSource, + CachedActivationSource, + TokenActivationSource, +) from lm_saes.activation.token_source import TokenSource +import math +from torch.distributed.device_mesh import init_device_mesh +import torch.distributed._functional_collectives as funcol + class ActivationStore: - def __init__(self, - act_source: ActivationSource, - d_model: int, - n_tokens_in_buffer=500000, - device="cuda", - use_ddp=False - ): + + def __init__(self, act_source: ActivationSource, cfg: ActivationStoreConfig): self.act_source = act_source - self.d_model = d_model - self.buffer_size = n_tokens_in_buffer - self.device = device - self.use_ddp = use_ddp + self.buffer_size = cfg.n_tokens_in_buffer + self.device = cfg.device + self.ddp_size = cfg.ddp_size # 1 8 + self.tp_size = cfg.tp_size self._store: Dict[str, torch.Tensor] = {} - + self._all_gather_buffer: Dict[str, torch.Tensor] = {} + self.device_mesh = init_device_mesh( + "cuda", (self.ddp_size, self.tp_size), mesh_dim_names=("ddp", "tp") + ) + def initialize(self): self.refill() self.shuffle() @@ -56,37 +64,71 @@ def __len__(self): def next(self, batch_size) -> Dict[str, torch.Tensor] | None: # Check if the activation store needs to be refilled. 
- need_refill = torch.tensor([self.__len__() < self.buffer_size // 2], device=self.device, dtype=torch.int) - if self.use_ddp: # When using DDP, we do refills in a synchronized manner to save time + need_refill = torch.tensor( + [self.__len__() < self.buffer_size // 2], + device=self.device, + dtype=torch.int, + ) + if ( + dist.is_initialized() + ): # When using DDP, we do refills in a synchronized manner to save time dist.all_reduce(need_refill, op=dist.ReduceOp.MAX) if need_refill.item() > 0: self.refill() self.shuffle() - if self.use_ddp: # Wait for all processes to refill the store + if dist.is_initialized(): # Wait for all processes to refill the store dist.barrier() + if self.tp_size > 1: + for k, v in self._store.items(): + if k not in self._all_gather_buffer: + self._all_gather_buffer[k] = torch.empty(size=(0,), dtype=v.dtype, device=self.device) + + gather_len = math.ceil( + (batch_size - self._all_gather_buffer[k].size(0)) / self.tp_size + ) + + assert gather_len <= v.size(0), "Not enough activations in the store" + gather_tensor = funcol.all_gather_tensor( + v[:gather_len], gather_dim=0, group=self.device_mesh["tp"] + ) + self._store[k] = v[gather_len:] + + self._all_gather_buffer[k] = torch.cat( + [self._all_gather_buffer[k], gather_tensor], dim=0 + ) + + ret = {k: self._all_gather_buffer[k][:batch_size] for k in self._store} + for k in self._store: + self._all_gather_buffer[k] = self._all_gather_buffer[k][batch_size:] + return ret if len(ret) > 0 else None + + else: + + ret = {k: v[:batch_size] for k, v in self._store.items()} + for k in self._store: + self._store[k] = self._store[k][batch_size:] + return ret if len(ret) > 0 else None - ret = {k: v[:batch_size] for k, v in self._store.items()} - for k in self._store: - self._store[k] = self._store[k][batch_size:] - return ret if len(ret) > 0 else None - def next_tokens(self, batch_size: int) -> torch.Tensor | None: - return self.act_source.next_tokens(batch_size) - + if self.tp_size > 1: + # TODO + next_tokens = self.act_source.next_tokens(batch_size) + funcol.broadcast(next_tokens, src=0, group=self.device_mesh["tp"]) + return next_tokens + else: + return self.act_source.next_tokens(batch_size) + @staticmethod def from_config(model: HookedTransformer, cfg: ActivationStoreConfig): act_source: ActivationSource if cfg.use_cached_activations: - act_source=CachedActivationSource(cfg=cfg) + act_source = CachedActivationSource(cfg=cfg) else: - act_source=TokenActivationSource( + act_source = TokenActivationSource( model=model, cfg=cfg, ) return ActivationStore( act_source=act_source, - d_model=cfg.lm.d_model, - n_tokens_in_buffer=cfg.n_tokens_in_buffer, - device=cfg.device, - use_ddp=cfg.use_ddp, - ) \ No newline at end of file + cfg=cfg, + ) diff --git a/src/lm_saes/activation/token_source.py b/src/lm_saes/activation/token_source.py index f825e54..3b15c55 100644 --- a/src/lm_saes/activation/token_source.py +++ b/src/lm_saes/activation/token_source.py @@ -4,10 +4,11 @@ from datasets import load_dataset, load_from_disk, Dataset from transformer_lens import HookedTransformer - +import torch.distributed as dist from lm_saes.config import TextDatasetConfig import random + class TokenSource: def __init__( self, @@ -28,9 +29,13 @@ def __init__( self.data_iter = [iter(dataloader) for dataloader in self.dataloader] - self.token_buffer = torch.empty((0, seq_len), dtype=torch.long, device=self.device) + self.token_buffer = torch.empty( + (0, seq_len), dtype=torch.long, device=self.device + ) - self.bos_token_id_tensor = 
torch.tensor([self.model.tokenizer.bos_token_id], dtype=torch.long, device=self.device) + self.bos_token_id_tensor = torch.tensor( + [self.model.tokenizer.bos_token_id], dtype=torch.long, device=self.device + ) self.resid = torch.tensor([], dtype=torch.long, device=self.device) self.sample_probs = sample_probs @@ -49,31 +54,43 @@ def fill_with_one_batch(self, batch, pack: bool, prepend_bos: bool) -> None: cur_tokens = cur_tokens[cur_tokens != self.model.tokenizer.eos_token_id] cur_tokens = cur_tokens[cur_tokens != self.model.tokenizer.pad_token_id] - self.resid = torch.cat([self.resid, self.bos_token_id_tensor.clone(), cur_tokens], dim=0) + self.resid = torch.cat( + [self.resid, self.bos_token_id_tensor.clone(), cur_tokens], dim=0 + ) while self.resid.size(0) >= self.seq_len: - self.token_buffer = torch.cat([self.token_buffer, self.resid[:self.seq_len].unsqueeze(0)], dim=0) - self.resid = self.resid[self.seq_len:] - self.resid = torch.cat([self.bos_token_id_tensor.clone(), self.resid], dim=0) + self.token_buffer = torch.cat( + [self.token_buffer, self.resid[: self.seq_len].unsqueeze(0)], + dim=0, + ) + self.resid = self.resid[self.seq_len :] + self.resid = torch.cat( + [self.bos_token_id_tensor.clone(), self.resid], dim=0 + ) tokens = tokens[1:] else: - tokens = tokens[:, :self.seq_len] + tokens = tokens[:, : self.seq_len] if tokens.size(1) < self.seq_len: pad_len = self.seq_len - tokens.size(1) tokens = torch.cat([tokens, torch.full((tokens.size(0), pad_len), self.model.tokenizer.pad_token_id, dtype=torch.long, device=self.device)], dim=1) self.token_buffer = torch.cat([self.token_buffer, tokens], dim=0) - def reset_iter(self, empty_idx: int): - self.data_iter = self.data_iter[:empty_idx] + self.data_iter[empty_idx + 1:] + self.data_iter = self.data_iter[:empty_idx] + self.data_iter[empty_idx + 1 :] - self.sample_probs = self.sample_probs[:empty_idx] + self.sample_probs[empty_idx + 1:] + self.sample_probs = ( + self.sample_probs[:empty_idx] + self.sample_probs[empty_idx + 1 :] + ) - self.sample_probs = [prob / sum(self.sample_probs) for prob in self.sample_probs] + self.sample_probs = [ + prob / sum(self.sample_probs) for prob in self.sample_probs + ] def next(self, batch_size: int) -> torch.Tensor | None: while self.token_buffer.size(0) < batch_size: - dataset_idx_to_fetch = random.choices(range(len(self.dataloader)), weights=self.sample_probs)[0] + dataset_idx_to_fetch = random.choices( + range(len(self.dataloader)), weights=self.sample_probs + )[0] try: batch = next(self.data_iter[dataset_idx_to_fetch]) except StopIteration: @@ -96,24 +113,24 @@ def _process_dataset(dataset_path: str, cfg: TextDatasetConfig): dataset = load_dataset(dataset_path, split="train", cache_dir=cfg.cache_dir) else: dataset = load_from_disk(dataset_path) - if cfg.use_ddp: - shard_id = cfg.rank - shard = dataset.shard(num_shards=cfg.world_size, index=shard_id) + if dist.is_initialized(): + shard_id = dist.get_rank() + shard = dataset.shard( + num_shards=dist.get_world_size(), index=shard_id + ) else: shard = dataset - if cfg.use_ddp: - shard_id = cfg.rank - shard = dataset.shard(num_shards=cfg.world_size, index=shard_id) - else: - shard = dataset dataloader = DataLoader(shard, batch_size=cfg.store_batch_size) return dataloader @staticmethod def from_config(model: HookedTransformer, cfg: TextDatasetConfig): - dataloader = [TokenSource._process_dataset(dataset_path, cfg) for dataset_path in cfg.dataset_path] + dataloader = [ + TokenSource._process_dataset(dataset_path, cfg) + for dataset_path in cfg.dataset_path + 
] return TokenSource( dataloader=dataloader, diff --git a/src/lm_saes/config.py b/src/lm_saes/config.py index e31212c..0c2dc18 100644 --- a/src/lm_saes/config.py +++ b/src/lm_saes/config.py @@ -10,9 +10,10 @@ from lm_saes.utils.config import FlattenableModel from lm_saes.utils.huggingface import parse_pretrained_name_or_path -from lm_saes.utils.misc import convert_str_to_torch_dtype, print_once +from lm_saes.utils.misc import convert_str_to_torch_dtype, print_once, is_master from transformer_lens.loading_from_pretrained import get_official_model_name +from torch.distributed.device_mesh import DeviceMesh @dataclass(kw_only=True) @@ -20,6 +21,7 @@ class BaseConfig(FlattenableModel): def __post_init__(self): pass + @dataclass(kw_only=True) class BaseModelConfig(BaseConfig): device: str = "cpu" @@ -30,37 +32,32 @@ def to_dict(self) -> Dict[str, Any]: return { field.name: getattr(self, field.name) for field in fields(self) - if field.name not in [base_field.name for base_field in fields(BaseModelConfig)] + if field.name + not in [base_field.name for base_field in fields(BaseModelConfig)] } - + @classmethod def from_dict(cls, d: Dict[str, Any], **kwargs): d = {k: v for k, v in d.items() if k in [field.name for field in fields(cls)]} return cls(**d, **kwargs) - + def __post_init__(self): super().__post_init__() if isinstance(self.dtype, str): self.dtype = convert_str_to_torch_dtype(self.dtype) - + if dist.is_initialized() and self.device == "cuda": + self.device = f"cuda:{dist.get_rank()}" + + @dataclass(kw_only=True) class RunnerConfig(BaseConfig): - use_ddp: bool = False - exp_name: str = "test" exp_series: Optional[str] = None exp_result_dir: str = "results" def __post_init__(self): super().__post_init__() - # Set rank, world_size, and device if using DDP - if self.use_ddp: - self.rank = dist.get_rank() - self.world_size = dist.get_world_size() - # if isinstance(self, BaseModelConfig): - # self.device = f"cuda:{self.rank}" - - if not self.use_ddp or self.rank == 0: + if is_master(): os.makedirs(self.exp_result_dir, exist_ok=True) os.makedirs(os.path.join(self.exp_result_dir, self.exp_name), exist_ok=True) @@ -92,18 +89,20 @@ def from_pretrained_sae(pretrained_name_or_path: str, **kwargs): return LanguageModelConfig.from_dict(lm_config, **kwargs) def save_lm_config(self, sae_path: str): - assert os.path.exists(sae_path), f"{sae_path} does not exist. Unable to save LanguageModelConfig." + assert os.path.exists( + sae_path + ), f"{sae_path} does not exist. Unable to save LanguageModelConfig." with open(os.path.join(sae_path, "lm_config.json"), "w") as f: - json.dump(self.to_dict(), f, indent=4) + json.dump(self.to_dict(), f, indent=4) @dataclass(kw_only=True) class TextDatasetConfig(RunnerConfig): - dataset_path: List[str] = 'openwebtext' # type: ignore + dataset_path: List[str] = "openwebtext" # type: ignore cache_dir: Optional[str] = None is_dataset_tokenized: bool = False is_dataset_on_disk: bool = False - concat_tokens: List[bool] = False # type: ignore + concat_tokens: List[bool] = False # type: ignore context_size: int = 128 store_batch_size: int = 64 sample_probs: List[float] = field(default_factory=lambda: [1.0]) @@ -135,9 +134,11 @@ class ActivationStoreConfig(BaseModelConfig, RunnerConfig): """ Hook points to store activations from, i.e. the layer output of which is used for training/evaluating the dictionary. Will run until the last hook point in the list, so make sure to order them correctly. 
""" use_cached_activations: bool = False - cached_activations_path: List[str] = None # type: ignore + cached_activations_path: List[str] = None # type: ignore n_tokens_in_buffer: int = 500_000 + tp_size: int = 1 + ddp_size: int = 1 def __post_init__(self): super().__post_init__() @@ -156,14 +157,16 @@ class WandbConfig(BaseConfig): exp_name: Optional[str] = None wandb_entity: Optional[str] = None + @dataclass(kw_only=True) class SAEConfig(BaseModelConfig): """ Configuration for training or running a sparse autoencoder. """ + hook_point_in: str = "blocks.0.hook_resid_pre" """ The hook point to use as input to the SAE. """ - hook_point_out: str = None # type: ignore + hook_point_out: str = None # type: ignore """ The hook point to use as label of the SAE. If None, it will be set to hook_point_in. """ sae_pretrained_name_or_path: Optional[str] = None @@ -190,6 +193,9 @@ class SAEConfig(BaseModelConfig): use_ghost_grads: bool = False + tp_size: int = 1 + ddp_size: int = 1 + def __post_init__(self): super().__post_init__() if self.hook_point_out is None: @@ -207,9 +213,11 @@ def __post_init__(self): @staticmethod - def from_pretrained(pretrained_name_or_path: str, strict_loading: bool = True, **kwargs): + def from_pretrained( + pretrained_name_or_path: str, strict_loading: bool = True, **kwargs + ): """Load the SAEConfig from a pretrained SAE name or path. Config is read from /hyperparams.json. - + Args: sae_path (str): The path to the pretrained SAE. **kwargs: Additional keyword arguments to pass to the SAEConfig constructor. @@ -220,7 +228,7 @@ def from_pretrained(pretrained_name_or_path: str, strict_loading: bool = True, * sae_config["sae_pretrained_name_or_path"] = pretrained_name_or_path sae_config["strict_loading"] = strict_loading return SAEConfig.from_dict(sae_config, **kwargs) - + @deprecated("Use from_pretrained and to_dict instead.") @staticmethod def get_hyperparameters( @@ -239,21 +247,25 @@ def get_hyperparameters( if k in SAEConfig.__dataclass_fields__.keys() } return hyperparams - + def save_hyperparameters(self, sae_path: str, remove_loading_info: bool = True): - assert os.path.exists(sae_path), f"{sae_path} does not exist. Unable to save hyperparameters." + assert os.path.exists( + sae_path + ), f"{sae_path} does not exist. Unable to save hyperparameters." 
d = self.to_dict() if remove_loading_info: d.pop("sae_pretrained_name_or_path", None) d.pop("strict_loading", None) with open(os.path.join(sae_path, "hyperparams.json"), "w") as f: json.dump(d, f, indent=4) - + + @dataclass(kw_only=True) class OpenAIConfig(BaseConfig): openai_api_key: str openai_base_url: str + @dataclass(kw_only=True) class AutoInterpConfig(BaseConfig): sae: SAEConfig @@ -272,6 +284,7 @@ class LanguageModelSAERunnerConfig(RunnerConfig): act_store: ActivationStoreConfig wandb: WandbConfig + @dataclass(kw_only=True) class LanguageModelSAETrainingConfig(LanguageModelSAERunnerConfig): """ @@ -311,7 +324,7 @@ class LanguageModelSAETrainingConfig(LanguageModelSAERunnerConfig): def __post_init__(self): super().__post_init__() - if not self.use_ddp or self.rank == 0: + if is_master(): if os.path.exists( os.path.join(self.exp_result_dir, self.exp_name, "checkpoints") ): @@ -320,17 +333,13 @@ def __post_init__(self): ) os.makedirs(os.path.join(self.exp_result_dir, self.exp_name, "checkpoints")) - - self.effective_batch_size = ( - self.train_batch_size * self.world_size - if self.use_ddp - else self.train_batch_size - ) + self.effective_batch_size = self.train_batch_size * self.sae.ddp_size print_once(f"Effective batch size: {self.effective_batch_size}") total_training_steps = self.total_training_tokens // self.effective_batch_size print_once(f"Total training steps: {total_training_steps}") + @dataclass(kw_only=True) class LanguageModelSAEPruningConfig(LanguageModelSAERunnerConfig): """ @@ -347,8 +356,11 @@ class LanguageModelSAEPruningConfig(LanguageModelSAERunnerConfig): def __post_init__(self): super().__post_init__() - if not self.use_ddp or self.rank == 0: - os.makedirs(os.path.join(self.exp_result_dir, self.exp_name, "checkpoints"), exist_ok=True) + if is_master(): + os.makedirs( + os.path.join(self.exp_result_dir, self.exp_name, "checkpoints"), + exist_ok=True, + ) @dataclass(kw_only=True) @@ -358,24 +370,30 @@ class ActivationGenerationConfig(RunnerConfig): hook_points: list[str] = field(default_factory=list) - activation_save_path: str = None # type: ignore + activation_save_path: str = None # type: ignore total_generating_tokens: int = 300_000_000 chunk_size: int = int(0.5 * 2**30) # 0.5 GB + ddp_size: int = 1 + tp_size: int = 1 def __post_init__(self): super().__post_init__() if self.activation_save_path is None: - assert isinstance(self.dataset_path, list) and len(self.dataset_path) == 1, "Only one dataset path is supported for activation generation." + assert ( + isinstance(self.dataset_path, list) and len(self.dataset_path) == 1 + ), "Only one dataset path is supported for activation generation." self.activation_save_path = f"activations/{self.dataset_path[0].split('/')[-1]}/{self.model_name.replace('/', '_')}_{self.context_size}" os.makedirs(self.activation_save_path, exist_ok=True) + @dataclass(kw_only=True) class MongoConfig(BaseConfig): mongo_uri: str = "mongodb://localhost:27017" mongo_db: str = "mechinterp" + @dataclass(kw_only=True) class LanguageModelSAEAnalysisConfig(RunnerConfig): """ @@ -392,13 +410,21 @@ class LanguageModelSAEAnalysisConfig(RunnerConfig): False # If True, we will sample the activations based on weights. Otherwise, top n_samples activations will be used. 
) sample_weight_exponent: float = 2.0 - subsample: Dict[str, Dict[str, Any]] = field(default_factory=lambda: { "top_activations": {"proportion": 1.0, "n_samples": 10} }) + subsample: Dict[str, Dict[str, Any]] = field( + default_factory=lambda: { + "top_activations": {"proportion": 1.0, "n_samples": 10} + } + ) - n_sae_chunks: int = 1 # Number of chunks to split the SAE into for analysis. For large models and SAEs, this can be useful to avoid memory issues. + n_sae_chunks: int = ( + 1 # Number of chunks to split the SAE into for analysis. For large models and SAEs, this can be useful to avoid memory issues. + ) def __post_init__(self): super().__post_init__() - assert self.sae.d_sae % self.n_sae_chunks == 0, f"d_sae ({self.sae.d_sae}) must be divisible by n_sae_chunks ({self.n_sae_chunks})" + assert ( + self.sae.d_sae % self.n_sae_chunks == 0 + ), f"d_sae ({self.sae.d_sae}) must be divisible by n_sae_chunks ({self.n_sae_chunks})" @dataclass(kw_only=True) diff --git a/src/lm_saes/entrypoint.py b/src/lm_saes/entrypoint.py index 91d7278..012ea20 100644 --- a/src/lm_saes/entrypoint.py +++ b/src/lm_saes/entrypoint.py @@ -48,13 +48,15 @@ def entrypoint(): from lm_saes.config import SAEConfig config['sae'] = SAEConfig.from_pretrained(args.sae).to_dict() - use_ddp = "use_ddp" in config and config["use_ddp"] - if use_ddp: + tp_size = config.get('tp_size', 1) + ddp_size = config.get('ddp_size', 1) + if tp_size > 1 or ddp_size > 1: import os import torch.distributed as dist os.environ["TOKENIZERS_PARALLELISM"] = "false" dist.init_process_group(backend='nccl') torch.cuda.set_device(dist.get_rank()) + if args.runner == SupportedRunner.TRAIN: from lm_saes.runner import language_model_sae_runner @@ -79,5 +81,6 @@ def entrypoint(): else: raise ValueError(f'Unsupported runner: {args.runner}.') - if use_ddp: + if dist.is_initialized(): dist.destroy_process_group() + diff --git a/src/lm_saes/evals.py b/src/lm_saes/evals.py index 75a2e42..425c034 100644 --- a/src/lm_saes/evals.py +++ b/src/lm_saes/evals.py @@ -10,8 +10,10 @@ from lm_saes.sae import SparseAutoEncoder from lm_saes.activation.activation_store import ActivationStore + # from lm_saes.activation_store_theirs import ActivationStoreTheirs from lm_saes.config import LanguageModelSAERunnerConfig +from lm_saes.utils.misc import is_master @torch.no_grad() def run_evals( @@ -63,14 +65,15 @@ def run_evals( l2_norm_in = torch.norm(original_act_out, dim=-1) l2_norm_out = torch.norm(reconstructed, dim=-1) - if cfg.use_ddp: - dist.reduce(l2_norm_in, dst=0, op=dist.ReduceOp.AVG) + if cfg.sae.ddp_size > 1: + dist.reduce( + l2_norm_in, dst=0, op=dist.ReduceOp.AVG + ) dist.reduce(l2_norm_out, dst=0, op=dist.ReduceOp.AVG) l2_norm_ratio = l2_norm_out / l2_norm_in l0 = (feature_acts > 0).float().sum(-1) - # TODO: DDP metrics = { # l2 norms @@ -85,7 +88,7 @@ def run_evals( "metrics/ce_loss_with_ablation": zero_abl_loss, } - if cfg.wandb.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0): + if cfg.wandb.log_to_wandb and is_master(): wandb.log( metrics, step=n_training_steps + 1, @@ -93,6 +96,7 @@ def run_evals( return metrics + def recons_loss_batched( model: HookedTransformer, sae: SparseAutoEncoder, @@ -101,15 +105,17 @@ def recons_loss_batched( n_batches: int = 100, ): losses = [] - if not cfg.use_ddp or cfg.rank == 0: + if is_master(): pbar = tqdm(total=n_batches, desc="Evaluation", smoothing=0.01) for _ in range(n_batches): - batch_tokens = activation_store.next_tokens(cfg.act_store.dataset.store_batch_size) + batch_tokens = activation_store.next_tokens( + 
cfg.act_store.dataset.store_batch_size + ) assert batch_tokens is not None, "Not enough tokens in the store" score, loss, recons_loss, zero_abl_loss = get_recons_loss( model, sae, cfg, batch_tokens ) - if cfg.use_ddp: + if cfg.sae.ddp_size > 1: dist.reduce(score, dst=0, op=dist.ReduceOp.AVG) dist.reduce(loss, dst=0, op=dist.ReduceOp.AVG) dist.reduce(recons_loss, dst=0, op=dist.ReduceOp.AVG) @@ -122,10 +128,10 @@ def recons_loss_batched( zero_abl_loss.mean().item(), ) ) - if not cfg.use_ddp or cfg.rank == 0: + if is_master(): pbar.update(1) - if not cfg.use_ddp or cfg.rank == 0: + if is_master(): pbar.close() losses = pd.DataFrame( @@ -151,12 +157,17 @@ def get_recons_loss( names_filter=[cfg.sae.hook_point_in, cfg.sae.hook_point_out], until=cfg.sae.hook_point_out, ) - activations_in, activations_out = cache[cfg.sae.hook_point_in], cache[cfg.sae.hook_point_out] - replacements = sae.forward(activations_in, label=activations_out).to(activations_out.dtype) + activations_in, activations_out = ( + cache[cfg.sae.hook_point_in], + cache[cfg.sae.hook_point_out], + ) + replacements = sae.forward(activations_in, label=activations_out).to( + activations_out.dtype + ) def replacement_hook(activations: torch.Tensor, hook: Any): return replacements - + recons_loss: torch.Tensor = model.run_with_hooks( batch_tokens, return_type="loss", diff --git a/src/lm_saes/runner.py b/src/lm_saes/runner.py index f461b85..8e9a05c 100644 --- a/src/lm_saes/runner.py +++ b/src/lm_saes/runner.py @@ -1,12 +1,13 @@ from typing import Any, cast import os +from pandas.core.algorithms import rank import wandb from dataclasses import asdict import torch - +import torch.distributed as dist from transformers import AutoModelForCausalLM, AutoTokenizer from transformer_lens import HookedTransformer, HookedTransformerConfig @@ -29,10 +30,11 @@ from lm_saes.analysis.sample_feature_activations import sample_feature_activations from lm_saes.analysis.features_to_logits import features_to_logits from torch.nn.parallel import DistributedDataParallel as DDP +from lm_saes.utils.misc import is_master def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig): - if (not cfg.use_ddp) or cfg.rank == 0: + if is_master(): cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name)) cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name)) sae = SparseAutoEncoder.from_config(cfg=cfg.sae) @@ -61,7 +63,7 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig): use_fast=True, add_bos_token=True, ) - + model = HookedTransformer.from_pretrained( cfg.lm.model_name, use_flash_attn=cfg.lm.use_flash_attn, @@ -71,9 +73,6 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig): tokenizer=hf_tokenizer, dtype=cfg.lm.dtype, ) - if cfg.use_ddp: - _ = DDP(model, device_ids=[cfg.rank]) - _ = DDP(sae, device_ids=[cfg.rank]) model.eval() activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store) @@ -96,7 +95,26 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig): cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name)) cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name)) - if cfg.wandb.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0): + if ( + cfg.sae.norm_activation == "dataset-wise" and cfg.sae.dataset_average_activation_norm is None + or cfg.sae.init_decoder_norm is None + ): + assert not cfg.finetuning + sae = SparseAutoEncoder.from_initialization_searching( + activation_store=activation_store, + cfg=cfg, + ) + else: + sae = 
SparseAutoEncoder.from_config(cfg=cfg.sae) + + if cfg.finetuning: + # Fine-tune SAE with frozen encoder weights and bias + sae.train_finetune_for_suppression_parameters() + + cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name)) + cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name)) + + if cfg.wandb.log_to_wandb and is_master(): wandb_config: dict = { **asdict(cfg), **asdict(cfg.sae), @@ -124,7 +142,7 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig): cfg, ) - if cfg.wandb.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0): + if cfg.wandb.log_to_wandb and is_master(): wandb.finish() return sae @@ -164,7 +182,7 @@ def language_model_sae_prune_runner(cfg: LanguageModelSAEPruningConfig): ) model.eval() activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store) - if cfg.wandb.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0): + if cfg.wandb.log_to_wandb and is_master(): wandb_config: dict = { **asdict(cfg), **asdict(cfg.sae), @@ -192,11 +210,11 @@ def language_model_sae_prune_runner(cfg: LanguageModelSAEPruningConfig): result = run_evals(model, sae, activation_store, cfg, 0) # Print results in tabular format - if not cfg.use_ddp or cfg.rank == 0: + if is_master(): for key, value in result.items(): print(f"{key}: {value}") - if cfg.wandb.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0): + if cfg.wandb.log_to_wandb and is_master(): wandb.finish() @@ -234,7 +252,7 @@ def language_model_sae_eval_runner(cfg: LanguageModelSAERunnerConfig): model.eval() activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store) - if cfg.wandb.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0): + if cfg.wandb.log_to_wandb and is_master(): wandb_config: dict = { **asdict(cfg), **asdict(cfg.sae), @@ -256,11 +274,11 @@ def language_model_sae_eval_runner(cfg: LanguageModelSAERunnerConfig): result = run_evals(model, sae, activation_store, cfg, 0) # Print results in tabular format - if not cfg.use_ddp or cfg.rank == 0: + if is_master(): for key, value in result.items(): print(f"{key}: {value}") - if cfg.wandb.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0): + if cfg.wandb.log_to_wandb and is_master(): wandb.finish() return sae @@ -338,7 +356,9 @@ def sample_feature_activations_runner(cfg: LanguageModelSAEAnalysisConfig): for chunk_id in range(cfg.n_sae_chunks): activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store) - result = sample_feature_activations(sae, model, activation_store, cfg, chunk_id, cfg.n_sae_chunks) + result = sample_feature_activations( + sae, model, activation_store, cfg, chunk_id, cfg.n_sae_chunks + ) for i in range(len(result["index"].cpu().numpy().tolist())): client.update_feature( diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py index 9db2b3c..3232723 100644 --- a/src/lm_saes/sae.py +++ b/src/lm_saes/sae.py @@ -2,6 +2,8 @@ import os from typing import Dict, Literal, Union, overload, List import torch +from torch.distributed.device_mesh import init_device_mesh +import torch.nn as nn import math from einops import einsum from jaxtyping import Float @@ -12,19 +14,18 @@ from lm_saes.config import SAEConfig, LanguageModelSAETrainingConfig from lm_saes.activation.activation_store import ActivationStore from lm_saes.utils.huggingface import parse_pretrained_name_or_path +import torch.distributed._functional_collectives as funcol + class SparseAutoEncoder(HookedRootModule): """Sparse AutoEncoder model. 
An autoencoder model that learns to compress the input activation tensor into a high-dimensional but sparse feature activation tensor. - + Can also act as a transcoder model, which learns to compress the input activation tensor into a feature activation tensor, and then reconstruct a label activation tensor from the feature activation tensor. """ - def __init__( - self, - cfg: SAEConfig - ): + def __init__(self, cfg: SAEConfig): """Initialize the SparseAutoEncoder model. Args: @@ -36,24 +37,39 @@ def __init__( self.cfg = cfg self.current_l1_coefficient = cfg.l1_coefficient - self.encoder = torch.nn.Parameter(torch.empty((cfg.d_model, cfg.d_sae), dtype=cfg.dtype, device=cfg.device)) + self.encoder = torch.nn.Linear( + cfg.d_model, cfg.d_sae, bias=True, device=cfg.device, dtype=cfg.dtype + ) + torch.nn.init.kaiming_uniform_(self.encoder.weight) + torch.nn.init.zeros_(self.encoder.bias) + self.device_mesh = init_device_mesh( + "cuda", (cfg.ddp_size, cfg.tp_size), mesh_dim_names=("ddp", "tp") + ) if cfg.use_glu_encoder: - self.encoder_glu = torch.nn.Parameter(torch.empty((cfg.d_model, cfg.d_sae), dtype=cfg.dtype, device=cfg.device)) - - self.encoder_bias_glu = torch.nn.Parameter(torch.empty((cfg.d_sae,), dtype=cfg.dtype, device=cfg.device)) - - self.feature_act_mask = torch.nn.Parameter(torch.ones((cfg.d_sae,), dtype=cfg.dtype, device=cfg.device)) - self.feature_act_scale = torch.nn.Parameter(torch.ones((cfg.d_sae,), dtype=cfg.dtype, device=cfg.device)) - self.decoder = torch.nn.Parameter(torch.empty((cfg.d_sae, cfg.d_model), dtype=cfg.dtype, device=cfg.device)) - - - if cfg.use_decoder_bias: - self.decoder_bias = torch.nn.Parameter(torch.empty((cfg.d_model,), dtype=cfg.dtype, device=cfg.device)) + self.encoder_glu = torch.nn.Linear( + cfg.d_model, cfg.d_sae, bias=True, device=cfg.device, dtype=cfg.dtype + ) + torch.nn.init.kaiming_uniform_(self.encoder_glu.weight) + torch.nn.init.zeros_(self.encoder_glu.bias) - self.encoder_bias = torch.nn.Parameter(torch.empty((cfg.d_sae,), dtype=cfg.dtype, device=cfg.device)) + self.feature_act_mask = torch.nn.Parameter( + torch.ones((cfg.d_sae,), dtype=cfg.dtype, device=cfg.device) + ) + self.feature_act_scale = torch.nn.Parameter( + torch.ones((cfg.d_sae,), dtype=cfg.dtype, device=cfg.device) + ) + self.decoder = torch.nn.Linear( + cfg.d_sae, + cfg.d_model, + bias=cfg.use_decoder_bias, + device=cfg.device, + dtype=cfg.dtype, + ) + torch.nn.init.kaiming_uniform_(self.decoder.weight) + self.set_decoder_norm_to_unit_norm() self.train_base_parameters() @@ -85,38 +101,35 @@ def initialize_parameters(self): def train_base_parameters(self): - """Set the base parameters to be trained. - """ + """Set the base parameters to be trained.""" base_parameters = [ - self.encoder, - self.decoder, - self.encoder_bias, + self.encoder.weight, + self.decoder.weight, + self.encoder.bias, ] if self.cfg.use_glu_encoder: - base_parameters.extend([self.encoder_glu, self.encoder_bias_glu]) + base_parameters.extend([self.encoder_glu.weight, self.encoder_glu.bias]) if self.cfg.use_decoder_bias: - base_parameters.append(self.decoder_bias) + base_parameters.append(self.decoder.bias) for p in self.parameters(): p.requires_grad_(False) for p in base_parameters: p.requires_grad_(True) def train_finetune_for_suppression_parameters(self): - """Set the parameters to be trained for feature suppression. 
- """ + """Set the parameters to be trained for feature suppression.""" finetune_for_suppression_parameters = [ self.feature_act_scale, - self.decoder, + self.decoder.weight, ] if self.cfg.use_decoder_bias: - finetune_for_suppression_parameters.append(self.decoder_bias) + finetune_for_suppression_parameters.append(self.decoder.bias) for p in self.parameters(): p.requires_grad_(False) for p in finetune_for_suppression_parameters: p.requires_grad_(True) - def compute_norm_factor(self, x: torch.Tensor, hook_point: str) -> float | torch.Tensor: """Compute the normalization factor for the activation vectors. @@ -132,29 +145,80 @@ def compute_norm_factor(self, x: torch.Tensor, hook_point: str) -> float | torch return math.sqrt(self.cfg.d_model) / self.cfg.dataset_average_activation_norm[hook_point] else: return torch.tensor(1.0, dtype=self.cfg.dtype, device=self.cfg.device) - + @overload def encode( self, - x: Union[Float[torch.Tensor, "batch d_model"], Float[torch.Tensor, "batch seq_len d_model"]], - label: Union[Float[torch.Tensor, "batch d_model"], Float[torch.Tensor, "batch seq_len d_model"]] | None = None, - return_hidden_pre: Literal[False] = False - ) -> Union[Float[torch.Tensor, "batch d_sae"], Float[torch.Tensor, "batch seq_len d_sae"]]: ... + x: Union[ + Float[torch.Tensor, "batch d_model"], + Float[torch.Tensor, "batch seq_len d_model"], + ], + label: ( + Union[ + Float[torch.Tensor, "batch d_model"], + Float[torch.Tensor, "batch seq_len d_model"], + ] + | None + ) = None, + return_hidden_pre: Literal[False] = False, + ) -> Union[ + Float[torch.Tensor, "batch d_sae"], Float[torch.Tensor, "batch seq_len d_sae"] + ]: ... @overload def encode( self, - x: Union[Float[torch.Tensor, "batch d_model"], Float[torch.Tensor, "batch seq_len d_model"]], - label: Union[Float[torch.Tensor, "batch d_model"], Float[torch.Tensor, "batch seq_len d_model"]] | None, - return_hidden_pre: Literal[True] - ) -> tuple[Union[Float[torch.Tensor, "batch d_sae"], Float[torch.Tensor, "batch seq_len d_sae"]], Union[Float[torch.Tensor, "batch d_sae"], Float[torch.Tensor, "batch seq_len d_sae"]]]: ... + x: Union[ + Float[torch.Tensor, "batch d_model"], + Float[torch.Tensor, "batch seq_len d_model"], + ], + label: ( + Union[ + Float[torch.Tensor, "batch d_model"], + Float[torch.Tensor, "batch seq_len d_model"], + ] + | None + ), + return_hidden_pre: Literal[True], + ) -> tuple[ + Union[ + Float[torch.Tensor, "batch d_sae"], + Float[torch.Tensor, "batch seq_len d_sae"], + ], + Union[ + Float[torch.Tensor, "batch d_sae"], + Float[torch.Tensor, "batch seq_len d_sae"], + ], + ]: ... 
def encode( - self, - x: Union[Float[torch.Tensor, "batch d_model"], Float[torch.Tensor, "batch seq_len d_model"]], - label: Union[Float[torch.Tensor, "batch d_model"], Float[torch.Tensor, "batch seq_len d_model"]] | None = None, - return_hidden_pre: bool = False - ) -> Union[Float[torch.Tensor, "batch d_sae"], Float[torch.Tensor, "batch seq_len d_sae"], tuple[Union[Float[torch.Tensor, "batch d_sae"], Float[torch.Tensor, "batch seq_len d_sae"]], Union[Float[torch.Tensor, "batch d_sae"], Float[torch.Tensor, "batch seq_len d_sae"]]]]: + self, + x: Union[ + Float[torch.Tensor, "batch d_model"], + Float[torch.Tensor, "batch seq_len d_model"], + ], + label: ( + Union[ + Float[torch.Tensor, "batch d_model"], + Float[torch.Tensor, "batch seq_len d_model"], + ] + | None + ) = None, + return_hidden_pre: bool = False, + ) -> Union[ + Float[torch.Tensor, "batch d_sae"], + Float[torch.Tensor, "batch seq_len d_sae"], + tuple[ + Union[ + Float[torch.Tensor, "batch d_sae"], + Float[torch.Tensor, "batch seq_len d_sae"], + ], + Union[ + Float[torch.Tensor, "batch d_sae"], + Float[torch.Tensor, "batch seq_len d_sae"], + ], + ], + ]: """Encode the model activation x into feature activations. Args: @@ -171,39 +235,42 @@ def encode( label = x if self.cfg.use_decoder_bias and self.cfg.apply_decoder_bias_to_pre_encoder: - x = x - self.decoder_bias + x = x - self.decoder.bias x = x * self.compute_norm_factor(x, hook_point='in') - hidden_pre = einsum( - x, - self.encoder, - "... d_model, d_model d_sae -> ... d_sae", - ) + self.encoder_bias + hidden_pre = self.encoder(x) if self.cfg.use_glu_encoder: - hidden_pre_glu = einsum( - x, - self.encoder_glu, - "... d_model, d_model d_sae -> ... d_sae", - ) + self.encoder_bias_glu - hidden_pre_glu = torch.sigmoid(hidden_pre_glu) + hidden_pre_glu = torch.sigmoid(self.encoder_glu(x)) + hidden_pre = hidden_pre * hidden_pre_glu hidden_pre = hidden_pre / self.compute_norm_factor(label, hook_point='in') hidden_pre = self.hook_hidden_pre(hidden_pre) - feature_acts = self.feature_act_mask * self.feature_act_scale * torch.clamp(hidden_pre, min=0.0) + feature_acts = ( + self.feature_act_mask + * self.feature_act_scale + * torch.clamp(hidden_pre, min=0.0) + ) + feature_acts = self.hook_feature_acts(feature_acts) if return_hidden_pre: return feature_acts, hidden_pre return feature_acts - + def decode( self, - feature_acts: Union[Float[torch.Tensor, "batch d_sae"], Float[torch.Tensor, "batch seq_len d_sae"]], - ) -> Union[Float[torch.Tensor, "batch d_model"], Float[torch.Tensor, "batch seq_len d_model"]]: + feature_acts: Union[ + Float[torch.Tensor, "batch d_sae"], + Float[torch.Tensor, "batch seq_len d_sae"], + ], + ) -> Union[ + Float[torch.Tensor, "batch d_model"], + Float[torch.Tensor, "batch seq_len d_model"], + ]: """Decode the feature activations into the reconstructed model activation in the label space. Args: @@ -213,25 +280,33 @@ def decode( torch.Tensor: The reconstructed model activation. Not normalized. """ - reconstructed = einsum( - feature_acts, - self.decoder, - "... d_sae, d_sae d_model -> ... 
d_model", - ) - if self.cfg.use_decoder_bias: - reconstructed = reconstructed + self.decoder_bias - + reconstructed = self.decoder(feature_acts) reconstructed = self.hook_reconstructed(reconstructed) return reconstructed - + def compute_loss( self, - x: Union[Float[torch.Tensor, "batch d_model"], Float[torch.Tensor, "batch seq_len d_model"]], + x: Union[ + Float[torch.Tensor, "batch d_model"], + Float[torch.Tensor, "batch seq_len d_model"], + ], dead_feature_mask: Float[torch.Tensor, "d_sae"] | None = None, - label: Union[Float[torch.Tensor, "batch d_model"], Float[torch.Tensor, "batch seq_len d_model"]] | None = None, - return_aux_data: bool = True - ) -> Union[Float[torch.Tensor, "batch"], tuple[Float[torch.Tensor, "batch"], tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]]]: + label: ( + Union[ + Float[torch.Tensor, "batch d_model"], + Float[torch.Tensor, "batch seq_len d_model"], + ] + | None + ) = None, + return_aux_data: bool = True, + ) -> Union[ + Float[torch.Tensor, "batch"], + tuple[ + Float[torch.Tensor, "batch"], + tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]], + ], + ]: """Compute the loss of the model. Args: @@ -258,7 +333,9 @@ def compute_loss( label_normed = label * label_norm_factor # l_rec: (batch, d_model) - l_rec = (reconstructed_normed - label_normed).pow(2) / (label_normed - label_normed.mean(dim=0, keepdim=True)).pow(2).sum(dim=-1, keepdim=True).clamp(min=1e-8).sqrt() + l_rec = (reconstructed_normed - label_normed).pow(2) / ( + label_normed - label_normed.mean(dim=0, keepdim=True) + ).pow(2).sum(dim=-1, keepdim=True).clamp(min=1e-8).sqrt() # l_l1: (batch,) if self.cfg.sparsity_include_decoder_norm: @@ -284,7 +361,9 @@ def compute_loss( # 2. feature_acts_dead_neurons_only = torch.exp(hidden_pre[:, dead_feature_mask]) - ghost_out = feature_acts_dead_neurons_only @ self.decoder[dead_feature_mask, :] + ghost_out = ( + feature_acts_dead_neurons_only @ self.decoder.weight[dead_feature_mask, :] + ) l2_norm_ghost_out = torch.norm(ghost_out, dim=-1) norm_scaling_factor = l2_norm_residual / (1e-6 + l2_norm_ghost_out * 2) ghost_out = ghost_out * norm_scaling_factor[:, None].detach() @@ -305,18 +384,31 @@ def compute_loss( "reconstructed": reconstructed, "hidden_pre": hidden_pre, } - return loss, ({"l_rec": l_rec, "l_l1": l_l1, "l_ghost_resid": l_ghost_resid}, aux_data) - - return loss + return loss, ( + {"l_rec": l_rec, "l_l1": l_l1, "l_ghost_resid": l_ghost_resid}, + aux_data, + ) + return loss def forward( self, - x: Union[Float[torch.Tensor, "batch d_model"], Float[torch.Tensor, "batch seq_len d_model"]], - label: Union[Float[torch.Tensor, "batch d_model"], Float[torch.Tensor, "batch seq_len d_model"]] | None = None, - ) -> Union[Float[torch.Tensor, "batch d_model"], Float[torch.Tensor, "batch seq_len d_model"]]: - """Encode and then decode the input activation tensor, outputting the reconstructed activation tensor. 
- """ + x: Union[ + Float[torch.Tensor, "batch d_model"], + Float[torch.Tensor, "batch seq_len d_model"], + ], + label: ( + Union[ + Float[torch.Tensor, "batch d_model"], + Float[torch.Tensor, "batch seq_len d_model"], + ] + | None + ) = None, + ) -> Union[ + Float[torch.Tensor, "batch d_model"], + Float[torch.Tensor, "batch seq_len d_model"], + ]: + """Encode and then decode the input activation tensor, outputting the reconstructed activation tensor.""" if label is None: label = x @@ -336,14 +428,14 @@ def update_l1_coefficient(self, training_step): def set_decoder_norm_to_fixed_norm(self, value: float | None = 1.0, force_exact: bool | None = None): if value is None: return - decoder_norm = torch.norm(self.decoder, dim=1, keepdim=True) + decoder_norm = torch.norm(self.decoder.weight, dim=1, keepdim=True) if force_exact is None: force_exact = self.cfg.decoder_exactly_fixed_norm if force_exact: - self.decoder.data = self.decoder.data * value / decoder_norm + self.decoder.data = self.decoder.weight.data * value / decoder_norm else: # Set the norm of the decoder to not exceed value - self.decoder.data = self.decoder.data * value / torch.clamp(decoder_norm, min=value) + self.decoder.weight.data = self.decoder.weight.data * value / torch.clamp(decoder_norm, min=value) @torch.no_grad() def set_encoder_norm_to_fixed_norm(self, value: float | None = 1.0): @@ -352,8 +444,8 @@ def set_encoder_norm_to_fixed_norm(self, value: float | None = 1.0): if value is None: print(f'Encoder norm is not set to a fixed value, using random initialization.') return - encoder_norm = torch.norm(self.encoder, dim=0, keepdim=True) # [1, d_sae] - self.encoder.data = self.encoder.data * value / encoder_norm + encoder_norm = torch.norm(self.encoder.weight, dim=0, keepdim=True) # [1, d_sae] + self.encoder.data = self.encoder.weight.data * value / encoder_norm @torch.no_grad() @@ -367,11 +459,11 @@ def transform_to_unit_decoder_norm(self): if self.cfg.use_glu_encoder: raise NotImplementedError("GLU encoder not supported") - decoder_norm = torch.norm(self.decoder, p=2, dim=1) # (d_sae,) - self.encoder.data = self.encoder.data * decoder_norm - self.decoder.data = self.decoder.data / decoder_norm[:, None] + decoder_norm = torch.norm(self.decoder.weight, p=2, dim=1) # (d_sae,) + self.encoder.data = self.encoder.weight.data * decoder_norm + self.decoder.data = self.decoder.weight.data / decoder_norm[:, None] - self.encoder_bias.data = self.encoder_bias.data * decoder_norm + self.encoder.bias.data = self.encoder.bias.data * decoder_norm @torch.no_grad() @@ -382,41 +474,45 @@ def remove_gradient_parallel_to_decoder_directions(self): """ parallel_component = einsum( - self.decoder.grad, - self.decoder.data, + self.decoder.weight.grad, + self.decoder.weight.data, "d_sae d_model, d_sae d_model -> d_sae", ) - assert self.decoder.grad is not None, "No gradient to remove parallel component from" + assert ( + self.decoder.weight.grad is not None + ), "No gradient to remove parallel component from" - self.decoder.grad -= einsum( + self.decoder.weight.grad -= einsum( parallel_component, - self.decoder.data, + self.decoder.weight.data, "d_sae, d_sae d_model -> d_sae d_model", ) - + @torch.no_grad() def compute_thomson_potential(self): - dist = torch.cdist(self.decoder, self.decoder, p=2).flatten()[1:].view(self.cfg.d_sae - 1, self.cfg.d_sae + 1)[:, :-1] + dist = ( + torch.cdist(self.decoder.weight, self.decoder.weight, p=2) + .flatten()[1:] + .view(self.cfg.d_sae - 1, self.cfg.d_sae + 1)[:, :-1] + ) mean_thomson_potential = (1 / dist).mean() 
return mean_thomson_potential - + @staticmethod - def from_config( - cfg: SAEConfig - ) -> "SparseAutoEncoder": + def from_config(cfg: SAEConfig) -> "SparseAutoEncoder": """Load the SparseAutoEncoder model from the pretrained configuration. Args: cfg (SAEConfig): The configuration of the model, containing the sae_pretrained_name_or_path. - + Returns: SparseAutoEncoder: The pretrained SparseAutoEncoder model. """ pretrained_name_or_path = cfg.sae_pretrained_name_or_path if pretrained_name_or_path is None: return SparseAutoEncoder(cfg) - + path = parse_pretrained_name_or_path(pretrained_name_or_path) if path.endswith(".pt") or path.endswith(".safetensors"): @@ -434,13 +530,15 @@ def from_config( if os.path.exists(ckpt_path): break else: - raise FileNotFoundError(f"Pretrained model not found at {pretrained_name_or_path}") - + raise FileNotFoundError( + f"Pretrained model not found at {pretrained_name_or_path}" + ) + if ckpt_path.endswith(".safetensors"): state_dict = safe.load_file(ckpt_path, device=cfg.device) else: state_dict = torch.load(ckpt_path, map_location=cfg.device)["sae"] - + model = SparseAutoEncoder(cfg) model.load_state_dict(state_dict, strict=cfg.strict_loading) @@ -448,9 +546,7 @@ def from_config( @staticmethod def from_pretrained( - pretrained_name_or_path: str, - strict_loading: bool = True, - **kwargs + pretrained_name_or_path: str, strict_loading: bool = True, **kwargs ) -> "SparseAutoEncoder": """Load the SparseAutoEncoder model from the pretrained configuration. @@ -458,11 +554,13 @@ def from_pretrained( pretrained_name_or_path (str): The name or path of the pretrained model. strict_loading (bool, optional): Whether to load the model strictly. Defaults to True. **kwargs: Additional keyword arguments as BaseModelConfig. - + Returns: SparseAutoEncoder: The pretrained SparseAutoEncoder model. """ - cfg = SAEConfig.from_pretrained(pretrained_name_or_path, strict_loading=strict_loading, **kwargs) + cfg = SAEConfig.from_pretrained( + pretrained_name_or_path, strict_loading=strict_loading, **kwargs + ) return SparseAutoEncoder.from_config(cfg) @@ -525,17 +623,21 @@ def save_pretrained( if os.path.isdir(ckpt_path): ckpt_path = os.path.join(ckpt_path, "sae_weights.safetensors") if ckpt_path.endswith(".safetensors"): - safe.save_file(self.state_dict(), ckpt_path, {"version": version("lm-saes")}) + safe.save_file( + self.state_dict(), ckpt_path, {"version": version("lm-saes")} + ) elif ckpt_path.endswith(".pt"): - torch.save({"sae": self.state_dict(), "version": version("lm-saes")}, ckpt_path) + torch.save( + {"sae": self.state_dict(), "version": version("lm-saes")}, ckpt_path + ) else: raise ValueError(f"Invalid checkpoint path {ckpt_path}. 
Currently only supports .safetensors and .pt formats.") @property def decoder_norm(self): - return torch.norm(self.decoder, p=2, dim=1).mean() + return torch.norm(self.decoder.weight, p=2, dim=1).mean() @property def encoder_norm(self): - return torch.norm(self.encoder, p=2, dim=0).mean() + return torch.norm(self.encoder.weight, p=2, dim=0).mean() diff --git a/src/lm_saes/sae_training.py b/src/lm_saes/sae_training.py index 2703843..5752eb3 100644 --- a/src/lm_saes/sae_training.py +++ b/src/lm_saes/sae_training.py @@ -19,7 +19,21 @@ from lm_saes.config import LanguageModelSAEPruningConfig, LanguageModelSAETrainingConfig from lm_saes.optim import get_scheduler from lm_saes.evals import run_evals -from lm_saes.utils.misc import print_once +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + RowwiseParallel, + parallelize_module, + loss_parallel, +) +from torch.distributed._tensor import ( + DTensor, + Shard, + Replicate, + distribute_module, + distribute_tensor, +) +from lm_saes.utils.misc import is_master, print_once def train_sae( @@ -44,7 +58,7 @@ def train_sae( range(0, total_training_tokens, total_training_tokens // cfg.n_checkpoints) )[1:] activation_store.initialize() - if not cfg.use_ddp or cfg.rank == 0: + if is_master(): print(f"Activation Store Initialized.") # Initialize the SAE decoder bias if necessary # if cfg.use_decoder_bias and (not cfg.use_ddp or cfg.rank == 0): @@ -58,6 +72,23 @@ def train_sae( ) n_frac_active_tokens = torch.tensor([0], device=cfg.sae.device, dtype=torch.int) + if cfg.sae.tp_size > 1: + plan = { + "encoder": ColwiseParallel(output_layouts=Replicate()), + "decoder": RowwiseParallel(input_layouts=Replicate()), + } + if cfg.sae.use_glu_encoder: + plan["encoder_glu"] = ColwiseParallel(output_layouts=Replicate()) + sae = parallelize_module( + sae, device_mesh=sae.device_mesh["tp"], parallelize_plan=plan + ) + + elif cfg.sae.ddp_size > 1: + _ = DDP(sae, device_mesh=sae.device_mesh["ddp"]) + # sae = parallelize_module( + # sae, device_mesh=sae.device_mesh["ddp"], parallelize_plan={} + # ) + optimizer = Adam(sae.parameters(), lr=cfg.lr, betas=cfg.betas) scheduler = get_scheduler( @@ -71,12 +102,13 @@ def train_sae( scheduler.step() - if not cfg.use_ddp or cfg.rank == 0: + if is_master(): pbar = tqdm(total=total_training_tokens, desc="Training SAE", smoothing=0.01) while n_training_tokens < total_training_tokens: sae.train() sae.update_l1_coefficient(n_training_steps) # Get the next batch of activations + batch = activation_store.next(batch_size=cfg.train_batch_size) assert batch is not None, "Activation store is empty" activation_in, activation_out = ( @@ -90,7 +122,6 @@ def train_sae( ghost_grad_neuron_mask = ( n_forward_passes_since_fired > cfg.dead_feature_window ).bool() - # Forward pass ( loss, @@ -107,19 +138,25 @@ def train_sae( did_fire = (aux_data["feature_acts"] > 0).float().sum(0) > 0 n_forward_passes_since_fired += 1 n_forward_passes_since_fired[did_fire] = 0 - if cfg.use_ddp: - dist.all_reduce(n_forward_passes_since_fired, op=dist.ReduceOp.MIN) + if cfg.sae.ddp_size > 1: + dist.all_reduce( + n_forward_passes_since_fired, + op=dist.ReduceOp.MIN, + ) if cfg.finetuning: loss = loss_data["l_rec"].mean() - loss.backward() + if cfg.sae.tp_size > 1: + with loss_parallel(): + loss.backward() + else: + loss.backward() if cfg.clip_grad_norm > 0: torch.nn.utils.clip_grad_norm_(sae.parameters(), cfg.clip_grad_norm) if cfg.remove_gradient_parallel_to_decoder_directions: 
sae.remove_gradient_parallel_to_decoder_directions() optimizer.step() - if not cfg.sae.sparsity_include_decoder_norm: sae.set_decoder_norm_to_fixed_norm(1) with torch.no_grad(): @@ -129,16 +166,16 @@ def train_sae( n_tokens_current = torch.tensor( activation_in.size(0), device=cfg.sae.device, dtype=torch.int ) - if cfg.use_ddp: + if cfg.sae.ddp_size > 1: dist.reduce(n_tokens_current, dst=0) n_training_tokens += cast(int, n_tokens_current.item()) # log and then reset the feature sparsity every feature_sampling_window steps if (n_training_steps + 1) % cfg.feature_sampling_window == 0: - if cfg.use_ddp: + if cfg.sae.ddp_size > 1: dist.reduce(act_freq_scores, dst=0) dist.reduce(n_frac_active_tokens, dst=0) - if cfg.wandb.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0): + if cfg.wandb.log_to_wandb and (is_master()): feature_sparsity = act_freq_scores / n_frac_active_tokens log_feature_sparsity = torch.log10(feature_sparsity + 1e-10) wandb_histogram = wandb.Histogram( @@ -170,12 +207,16 @@ def train_sae( l_l1 = loss_data["l_l1"].mean() l_ghost_resid = loss_data["l_ghost_resid"].mean() - if cfg.use_ddp: + if cfg.sae.ddp_size > 1: dist.reduce(loss, dst=0, op=dist.ReduceOp.AVG) dist.reduce(l0, dst=0, op=dist.ReduceOp.AVG) dist.reduce(l_rec, dst=0, op=dist.ReduceOp.AVG) dist.reduce(l_l1, dst=0, op=dist.ReduceOp.AVG) - dist.reduce(l_ghost_resid, dst=0, op=dist.ReduceOp.AVG) + dist.reduce( + l_ghost_resid, + dst=0, + op=dist.ReduceOp.AVG, + ) per_token_l2_loss = ( (aux_data["reconstructed"] - activation_out).pow(2).sum(dim=-1) @@ -189,11 +230,19 @@ def train_sae( l2_norm_error / activation_out.norm(p=2, dim=-1).mean() ) - if cfg.use_ddp: - dist.reduce(l2_norm_error, dst=0, op=dist.ReduceOp.AVG) - dist.reduce(l2_norm_error_ratio, dst=0, op=dist.ReduceOp.AVG) + if cfg.sae.ddp_size > 1: + dist.reduce( + l2_norm_error, + dst=0, + op=dist.ReduceOp.AVG, + ) + dist.reduce( + l2_norm_error_ratio, + dst=0, + op=dist.ReduceOp.AVG, + ) - if cfg.rank == 0: + if dist.get_rank() == 0: per_token_l2_loss_list = [ torch.zeros_like(per_token_l2_loss) for _ in range(dist.get_world_size()) @@ -204,15 +253,15 @@ def train_sae( ] dist.gather( per_token_l2_loss, - per_token_l2_loss_list if cfg.rank == 0 else None, + per_token_l2_loss_list if dist.get_rank() == 0 else None, dst=0, ) dist.gather( total_variance, - total_variance_list if cfg.rank == 0 else None, + total_variance_list if dist.get_rank() == 0 else None, dst=0, ) - if cfg.rank == 0: + if dist.get_rank() == 0: per_token_l2_loss = torch.cat(per_token_l2_loss_list, dim=0) total_variance = torch.cat(total_variance_list, dim=0) @@ -222,7 +271,7 @@ def train_sae( current_learning_rate = optimizer.param_groups[0]["lr"] - if cfg.wandb.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0): + if cfg.wandb.log_to_wandb and is_master(): wandb.log( { # losses @@ -246,10 +295,10 @@ def train_sae( "sparsity/l1_coefficient": sae.current_l1_coefficient, "sparsity/mean_passes_since_fired": n_forward_passes_since_fired.mean().item(), "sparsity/dead_features": ghost_grad_neuron_mask.sum().item(), - "sparsity/useful_features": sae.decoder.norm(p=2, dim=1) - .gt(0.99) - .sum() - .item(), + # "sparsity/useful_features": sae.decoder.weight.norm(p=2, dim=1) + # .gt(0.99) + # .sum() + # .item(), "details/current_learning_rate": current_learning_rate, "details/n_training_tokens": n_training_tokens, }, @@ -272,7 +321,7 @@ def train_sae( if ( len(checkpoint_thresholds) > 0 and n_training_tokens >= checkpoint_thresholds[0] - and (not cfg.use_ddp or cfg.rank == 0) + and is_master() ): # Save 
the model and optimizer state path = os.path.join( @@ -289,7 +338,7 @@ def train_sae( n_training_steps += 1 - if not cfg.use_ddp or cfg.rank == 0: + if is_master(): l_rec = loss_data["l_rec"].mean().item() l_l1 = loss_data["l_l1"].mean().item() pbar.set_description( @@ -297,11 +346,11 @@ def train_sae( ) pbar.update(n_tokens_current.item()) - if not cfg.use_ddp or cfg.rank == 0: + if is_master(): pbar.close() # Save the final model - if not cfg.use_ddp or cfg.rank == 0: + if is_master(): path = os.path.join( cfg.exp_result_dir, cfg.exp_name, "checkpoints", "final.safetensors" ) @@ -324,10 +373,10 @@ def prune_sae( max_acts = torch.zeros(cfg.sae.d_sae, device=cfg.sae.device, dtype=cfg.sae.dtype) activation_store.initialize() - if cfg.use_ddp: - ddp = DDP(sae, device_ids=[cfg.rank], output_device=cfg.sae.device) + if cfg.sae.ddp_size > 1: + _ = DDP(sae, device_mesh=sae.device_mesh["ddp"]) - if not cfg.use_ddp or cfg.rank == 0: + if is_master(): pbar = tqdm(total=cfg.total_training_tokens, desc="Pruning SAE", smoothing=0.01) while n_training_tokens < cfg.total_training_tokens: # Get the next batch of activations @@ -344,21 +393,21 @@ def prune_sae( max_acts = torch.max(max_acts, feature_acts.max(0).values) n_tokens_current = activation_in.size(0) - if cfg.use_ddp: + if cfg.sae.ddp_size > 1: dist.reduce(n_tokens_current, dst=0) n_training_tokens += n_tokens_current - if not cfg.use_ddp or cfg.rank == 0: + if is_master(): pbar.update(n_tokens_current) - if not cfg.use_ddp or cfg.rank == 0: + if is_master(): pbar.close() - if cfg.use_ddp: + if cfg.sae.ddp_size > 1: dist.reduce(act_times, dst=0, op=dist.ReduceOp.SUM) dist.reduce(max_acts, dst=0, op=dist.ReduceOp.MAX) - if not cfg.use_ddp or cfg.rank == 0: + if is_master(): sae.feature_act_mask.data = ( (act_times > cfg.dead_feature_threshold * cfg.total_training_tokens) & (max_acts > cfg.dead_feature_max_act_threshold) diff --git a/src/lm_saes/utils/misc.py b/src/lm_saes/utils/misc.py index b99de82..df81ba8 100644 --- a/src/lm_saes/utils/misc.py +++ b/src/lm_saes/utils/misc.py @@ -3,16 +3,15 @@ import torch import torch.distributed as dist +def is_master() -> bool: + return not dist.is_initialized() or dist.get_rank() == 0 def print_once( *values: object, sep: str | None = " ", end: str | None = "\n", ) -> None: - if dist.is_initialized(): - if dist.get_rank() == 0: - print(*values, sep=sep, end=end) - else: + if is_master(): print(*values, sep=sep, end=end) def check_file_path_unused(file_path):
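
For reference, below is a minimal end-to-end sketch of the tensor-parallel plan that `train_sae` applies to the SAE in this patch: the encoder `nn.Linear` is column-sharded and the decoder row-sharded across the "tp" dimension of a 2-D device mesh, with `Replicate()` layouts at the module boundary so feature activations stay whole on every rank. The `ToySAE` module, its dimensions, and the `torchrun` launch are illustrative assumptions only, not part of the patch.

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)
from torch.distributed._tensor import Replicate


class ToySAE(nn.Module):
    # Hypothetical stand-in for SparseAutoEncoder: a Linear encoder/decoder pair.
    def __init__(self, d_model: int = 768, d_sae: int = 8 * 768):
        super().__init__()
        self.encoder = nn.Linear(d_model, d_sae)
        self.decoder = nn.Linear(d_sae, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decoder(torch.relu(self.encoder(x)))


def main() -> None:
    # Assumes launch via `torchrun --nproc_per_node=<tp_size> this_script.py`.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())
    ddp_size, tp_size = 1, dist.get_world_size()
    mesh = init_device_mesh("cuda", (ddp_size, tp_size), mesh_dim_names=("ddp", "tp"))

    sae = ToySAE().cuda()
    # Same plan as in train_sae: shard encoder columns (each rank owns a slice of
    # the d_sae features) and decoder rows; Replicate() keeps the hidden activations
    # whole at the encoder output / decoder input, so elementwise ops such as the
    # ReLU (or the feature mask/scale in the real SAE) need no changes.
    plan = {
        "encoder": ColwiseParallel(output_layouts=Replicate()),
        "decoder": RowwiseParallel(input_layouts=Replicate()),
    }
    sae = parallelize_module(sae, device_mesh=mesh["tp"], parallelize_plan=plan)

    x = torch.randn(4, 768, device="cuda")
    print(sae(x).shape)  # torch.Size([4, 768]) on every rank
    dist.destroy_process_group()


if __name__ == "__main__":
    main()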