feat: Implement tensor parallelism in SAE using device mesh
- Introduced tensor parallelism; it may still contain bugs and needs further testing. Existing data parallelism remains functional.
- Save and load functionality for the SAE under tensor parallelism is not yet implemented.
- Additional tests are needed for the SAE and the activation store under tensor parallelism.
Frankstein73 committed Jul 14, 2024
1 parent 5bf9e6a commit ccac63a
Showing 11 changed files with 558 additions and 290 deletions.
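The commit wires tensor parallelism through a 2-D device mesh with named "ddp" and "tp" dimensions (see the init_device_mesh call added to ActivationStore below). A minimal sketch of that setup, assuming PyTorch 2.2+ and a torchrun launch with ddp_size * tp_size processes; the helper name build_mesh is illustrative and not part of the commit:

# Illustrative sketch only, not part of the commit: build the 2-D ("ddp", "tp")
# device mesh that the new ActivationStore code relies on.
from torch.distributed.device_mesh import init_device_mesh

def build_mesh(ddp_size: int, tp_size: int):
    # Expects the usual torchrun environment (RANK, WORLD_SIZE, LOCAL_RANK)
    # with world_size == ddp_size * tp_size; the default process group is
    # created if it does not exist yet.
    mesh = init_device_mesh("cuda", (ddp_size, tp_size), mesh_dim_names=("ddp", "tp"))
    # Named indexing yields 1-D sub-meshes that can be passed as the group
    # argument of functional collectives, e.g. mesh["tp"] for the tp group.
    return mesh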
13 changes: 6 additions & 7 deletions src/lm_saes/activation/activation_dataset.py
@@ -5,10 +5,9 @@
import os
import torch.distributed as dist

from lm_saes.utils.misc import print_once
from lm_saes.config import ActivationGenerationConfig
from lm_saes.activation.token_source import TokenSource

from lm_saes.utils.misc import is_master, print_once

@torch.no_grad()
def make_activation_dataset(
@@ -22,19 +21,19 @@ def make_activation_dataset(

token_source = TokenSource.from_config(model=model, cfg=cfg.dataset)

if not cfg.use_ddp or cfg.rank == 0:
if is_master():
for hook_point in cfg.hook_points:
os.makedirs(os.path.join(cfg.activation_save_path, hook_point), exist_ok=False)

if cfg.use_ddp:
if cfg.ddp_size > 1:
dist.barrier()
total_generating_tokens = cfg.total_generating_tokens // cfg.world_size
total_generating_tokens = cfg.total_generating_tokens // dist.get_world_size()
else:
total_generating_tokens = cfg.total_generating_tokens

n_tokens = 0
chunk_idx = 0
pbar = tqdm(total=total_generating_tokens, desc=f"Activation dataset Rank {cfg.rank}" if cfg.use_ddp else "Activation dataset")
pbar = tqdm(total=total_generating_tokens, desc=f"Activation dataset Rank {dist.get_rank()}" if dist.is_initialized() else "Activation dataset")

while n_tokens < total_generating_tokens:
act_dict = {hook_point: torch.empty((0, cfg.dataset.context_size, cfg.lm.d_model), dtype=cfg.lm.dtype, device=cfg.lm.device) for hook_point in cfg.hook_points}
@@ -63,7 +62,7 @@ def make_activation_dataset(
"context": context,
"position": position,
},
os.path.join(cfg.activation_save_path, hook_point, f"chunk-{str(chunk_idx).zfill(5)}.pt" if not cfg.use_ddp else f"shard-{cfg.rank}-chunk-{str(chunk_idx).zfill(5)}.pt")
os.path.join(cfg.activation_save_path, hook_point, f"chunk-{str(chunk_idx).zfill(5)}.pt" if not dist.is_initialized() else f"shard-{dist.get_rank()}-chunk-{str(chunk_idx).zfill(5)}.pt")
)
chunk_idx += 1

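Note on the helper introduced above: is_master() comes from lm_saes.utils.misc and its body is not shown in this diff. A plausible equivalent, stated purely as an assumption, would be:

# Assumed behaviour only; the actual lm_saes.utils.misc.is_master is not shown
# in this commit. Treat the process as master when no process group is
# initialized, or when it holds global rank 0.
import torch.distributed as dist

def is_master() -> bool:
    return not dist.is_initialized() or dist.get_rank() == 0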
4 changes: 2 additions & 2 deletions src/lm_saes/activation/activation_source.py
@@ -66,8 +66,8 @@ def __init__(self, cfg: ActivationStoreConfig):
assert len(cfg.hook_points) == 1, "CachedActivationSource only supports one hook point"
self.hook_point = cfg.hook_points[0]
self.chunk_paths = list_activation_chunks(cfg.cached_activations_path[0], self.hook_point)
if cfg.use_ddp:
self.chunk_paths = [p for i, p in enumerate(self.chunk_paths) if i % cfg.world_size == cfg.rank]
if cfg.ddp_size > 1:
self.chunk_paths = [p for i, p in enumerate(self.chunk_paths) if i % dist.get_world_size() == dist.get_rank()]
random.shuffle(self.chunk_paths)

self.token_buffer = torch.empty((0, cfg.dataset.context_size), dtype=torch.long, device=cfg.device)
102 changes: 72 additions & 30 deletions src/lm_saes/activation/activation_store.py
@@ -1,29 +1,37 @@
from typing import Callable, Dict, Generator, Iterable
import torch
from torch.cuda import is_initialized
import torch.distributed as dist
import torch.utils.data

from transformer_lens import HookedTransformer

from lm_saes.config import ActivationStoreConfig
from lm_saes.activation.activation_source import ActivationSource, CachedActivationSource, TokenActivationSource
from lm_saes.activation.activation_source import (
ActivationSource,
CachedActivationSource,
TokenActivationSource,
)
from lm_saes.activation.token_source import TokenSource
import math
from torch.distributed.device_mesh import init_device_mesh
import torch.distributed._functional_collectives as funcol


class ActivationStore:
def __init__(self,
act_source: ActivationSource,
d_model: int,
n_tokens_in_buffer=500000,
device="cuda",
use_ddp=False
):

def __init__(self, act_source: ActivationSource, cfg: ActivationStoreConfig):
self.act_source = act_source
self.d_model = d_model
self.buffer_size = n_tokens_in_buffer
self.device = device
self.use_ddp = use_ddp
self.buffer_size = cfg.n_tokens_in_buffer
self.device = cfg.device
self.ddp_size = cfg.ddp_size # 1 8
self.tp_size = cfg.tp_size
self._store: Dict[str, torch.Tensor] = {}

self._all_gather_buffer: Dict[str, torch.Tensor] = {}
self.device_mesh = init_device_mesh(
"cuda", (self.ddp_size, self.tp_size), mesh_dim_names=("ddp", "tp")
)

def initialize(self):
self.refill()
self.shuffle()
@@ -56,37 +64,71 @@ def __len__(self):

def next(self, batch_size) -> Dict[str, torch.Tensor] | None:
# Check if the activation store needs to be refilled.
need_refill = torch.tensor([self.__len__() < self.buffer_size // 2], device=self.device, dtype=torch.int)
if self.use_ddp: # When using DDP, we do refills in a synchronized manner to save time
need_refill = torch.tensor(
[self.__len__() < self.buffer_size // 2],
device=self.device,
dtype=torch.int,
)
if (
dist.is_initialized()
): # When using DDP, we do refills in a synchronized manner to save time
dist.all_reduce(need_refill, op=dist.ReduceOp.MAX)
if need_refill.item() > 0:
self.refill()
self.shuffle()
if self.use_ddp: # Wait for all processes to refill the store
if dist.is_initialized(): # Wait for all processes to refill the store
dist.barrier()
if self.tp_size > 1:
for k, v in self._store.items():
if k not in self._all_gather_buffer:
self._all_gather_buffer[k] = torch.empty(size=(0,), dtype=v.dtype, device=self.device)

gather_len = math.ceil(
(batch_size - self._all_gather_buffer[k].size(0)) / self.tp_size
)

assert gather_len <= v.size(0), "Not enough activations in the store"
gather_tensor = funcol.all_gather_tensor(
v[:gather_len], gather_dim=0, group=self.device_mesh["tp"]
)
self._store[k] = v[gather_len:]

self._all_gather_buffer[k] = torch.cat(
[self._all_gather_buffer[k], gather_tensor], dim=0
)

ret = {k: self._all_gather_buffer[k][:batch_size] for k in self._store}
for k in self._store:
self._all_gather_buffer[k] = self._all_gather_buffer[k][batch_size:]
return ret if len(ret) > 0 else None

else:

ret = {k: v[:batch_size] for k, v in self._store.items()}
for k in self._store:
self._store[k] = self._store[k][batch_size:]
return ret if len(ret) > 0 else None

ret = {k: v[:batch_size] for k, v in self._store.items()}
for k in self._store:
self._store[k] = self._store[k][batch_size:]
return ret if len(ret) > 0 else None

def next_tokens(self, batch_size: int) -> torch.Tensor | None:
return self.act_source.next_tokens(batch_size)

if self.tp_size > 1:
# TODO
next_tokens = self.act_source.next_tokens(batch_size)
funcol.broadcast(next_tokens, src=0, group=self.device_mesh["tp"])
return next_tokens
else:
return self.act_source.next_tokens(batch_size)

@staticmethod
def from_config(model: HookedTransformer, cfg: ActivationStoreConfig):
act_source: ActivationSource
if cfg.use_cached_activations:
act_source=CachedActivationSource(cfg=cfg)
act_source = CachedActivationSource(cfg=cfg)
else:
act_source=TokenActivationSource(
act_source = TokenActivationSource(
model=model,
cfg=cfg,
)
return ActivationStore(
act_source=act_source,
d_model=cfg.lm.d_model,
n_tokens_in_buffer=cfg.n_tokens_in_buffer,
device=cfg.device,
use_ddp=cfg.use_ddp,
)
cfg=cfg,
)
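To make the tensor-parallel branch of ActivationStore.next easier to follow: each tp rank contributes roughly ceil(batch_size / tp_size) rows from its local store, and the functional all-gather concatenates the contributions along the batch dimension. A stripped-down sketch of that pattern; names and shapes are illustrative, not the commit's API:

# Stripped-down, illustrative sketch of the tp gather in ActivationStore.next.
# With batch_size=4096 and tp_size=4, each rank contributes ceil(4096/4)=1024
# rows and the all-gather returns a 4096-row batch shared by the tp group.
import math
import torch
import torch.distributed._functional_collectives as funcol

def gather_tp_batch(local_acts: torch.Tensor, batch_size: int, tp_mesh) -> torch.Tensor:
    gather_len = math.ceil(batch_size / tp_mesh.size())  # rows taken from the local store
    shard = local_acts[:gather_len]                       # (gather_len, d_model)
    gathered = funcol.all_gather_tensor(shard, gather_dim=0, group=tp_mesh)
    return gathered[:batch_size]                          # (batch_size, d_model)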
61 changes: 39 additions & 22 deletions src/lm_saes/activation/token_source.py
@@ -4,10 +4,11 @@
from datasets import load_dataset, load_from_disk, Dataset

from transformer_lens import HookedTransformer

import torch.distributed as dist
from lm_saes.config import TextDatasetConfig
import random


class TokenSource:
def __init__(
self,
@@ -28,9 +29,13 @@ def __init__(

self.data_iter = [iter(dataloader) for dataloader in self.dataloader]

self.token_buffer = torch.empty((0, seq_len), dtype=torch.long, device=self.device)
self.token_buffer = torch.empty(
(0, seq_len), dtype=torch.long, device=self.device
)

self.bos_token_id_tensor = torch.tensor([self.model.tokenizer.bos_token_id], dtype=torch.long, device=self.device)
self.bos_token_id_tensor = torch.tensor(
[self.model.tokenizer.bos_token_id], dtype=torch.long, device=self.device
)
self.resid = torch.tensor([], dtype=torch.long, device=self.device)

self.sample_probs = sample_probs
@@ -49,31 +54,43 @@ def fill_with_one_batch(self, batch, pack: bool, prepend_bos: bool) -> None:
cur_tokens = cur_tokens[cur_tokens != self.model.tokenizer.eos_token_id]
cur_tokens = cur_tokens[cur_tokens != self.model.tokenizer.pad_token_id]

self.resid = torch.cat([self.resid, self.bos_token_id_tensor.clone(), cur_tokens], dim=0)
self.resid = torch.cat(
[self.resid, self.bos_token_id_tensor.clone(), cur_tokens], dim=0
)
while self.resid.size(0) >= self.seq_len:
self.token_buffer = torch.cat([self.token_buffer, self.resid[:self.seq_len].unsqueeze(0)], dim=0)
self.resid = self.resid[self.seq_len:]
self.resid = torch.cat([self.bos_token_id_tensor.clone(), self.resid], dim=0)
self.token_buffer = torch.cat(
[self.token_buffer, self.resid[: self.seq_len].unsqueeze(0)],
dim=0,
)
self.resid = self.resid[self.seq_len :]
self.resid = torch.cat(
[self.bos_token_id_tensor.clone(), self.resid], dim=0
)
tokens = tokens[1:]
else:
tokens = tokens[:, :self.seq_len]
tokens = tokens[:, : self.seq_len]

if tokens.size(1) < self.seq_len:
pad_len = self.seq_len - tokens.size(1)
tokens = torch.cat([tokens, torch.full((tokens.size(0), pad_len), self.model.tokenizer.pad_token_id, dtype=torch.long, device=self.device)], dim=1)
self.token_buffer = torch.cat([self.token_buffer, tokens], dim=0)


def reset_iter(self, empty_idx: int):
self.data_iter = self.data_iter[:empty_idx] + self.data_iter[empty_idx + 1:]
self.data_iter = self.data_iter[:empty_idx] + self.data_iter[empty_idx + 1 :]

self.sample_probs = self.sample_probs[:empty_idx] + self.sample_probs[empty_idx + 1:]
self.sample_probs = (
self.sample_probs[:empty_idx] + self.sample_probs[empty_idx + 1 :]
)

self.sample_probs = [prob / sum(self.sample_probs) for prob in self.sample_probs]
self.sample_probs = [
prob / sum(self.sample_probs) for prob in self.sample_probs
]

def next(self, batch_size: int) -> torch.Tensor | None:
while self.token_buffer.size(0) < batch_size:
dataset_idx_to_fetch = random.choices(range(len(self.dataloader)), weights=self.sample_probs)[0]
dataset_idx_to_fetch = random.choices(
range(len(self.dataloader)), weights=self.sample_probs
)[0]
try:
batch = next(self.data_iter[dataset_idx_to_fetch])
except StopIteration:
@@ -96,24 +113,24 @@ def _process_dataset(dataset_path: str, cfg: TextDatasetConfig):
dataset = load_dataset(dataset_path, split="train", cache_dir=cfg.cache_dir)
else:
dataset = load_from_disk(dataset_path)
if cfg.use_ddp:
shard_id = cfg.rank
shard = dataset.shard(num_shards=cfg.world_size, index=shard_id)
if dist.is_initialized():
shard_id = dist.get_rank()
shard = dataset.shard(
num_shards=dist.get_world_size(), index=shard_id
)
else:
shard = dataset

if cfg.use_ddp:
shard_id = cfg.rank
shard = dataset.shard(num_shards=cfg.world_size, index=shard_id)
else:
shard = dataset

dataloader = DataLoader(shard, batch_size=cfg.store_batch_size)
return dataloader

@staticmethod
def from_config(model: HookedTransformer, cfg: TextDatasetConfig):
dataloader = [TokenSource._process_dataset(dataset_path, cfg) for dataset_path in cfg.dataset_path]
dataloader = [
TokenSource._process_dataset(dataset_path, cfg)
for dataset_path in cfg.dataset_path
]

return TokenSource(
dataloader=dataloader,
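On the data-loading side, the reworked _process_dataset shards each dataset across ranks before wrapping it in a DataLoader. A condensed sketch of that behaviour; the function name shard_for_rank is illustrative:

# Condensed, illustrative sketch of the sharding now done in _process_dataset:
# when torch.distributed is initialized, each rank keeps 1/world_size of the
# dataset, indexed by its global rank.
import torch.distributed as dist
from datasets import Dataset
from torch.utils.data import DataLoader

def shard_for_rank(dataset: Dataset, store_batch_size: int) -> DataLoader:
    if dist.is_initialized():
        dataset = dataset.shard(num_shards=dist.get_world_size(), index=dist.get_rank())
    return DataLoader(dataset, batch_size=store_batch_size)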
