support tensor parallel analysis #51

Merged · 5 commits · Aug 24, 2024
Changes from all commits
2 changes: 1 addition & 1 deletion examples/configuration/analyze.toml
@@ -7,7 +7,7 @@ dtype = "torch.float32"

exp_name = "L3M"
exp_series = "default"
exp_result_dir = "results"
exp_result_path = "results/L3M"

[subsample]
"top_activations" = { "proportion" = 1.0, "n_samples" = 80 }
2 changes: 1 addition & 1 deletion examples/configuration/prune.toml
@@ -5,7 +5,7 @@ dtype = "torch.float32"

exp_name = "L3M"
exp_series = "default"
exp_result_dir = "results"
exp_result_path = "results/L3M"

total_training_tokens = 10_000_000
train_batch_size = 4096
2 changes: 1 addition & 1 deletion examples/configuration/train.toml
@@ -1,6 +1,6 @@
use_ddp = false
exp_name = "L3M"
exp_result_dir = "results"
exp_result_path = "results/L3M"
device = "cuda"
seed = 42
dtype = "torch.float32"
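Across the example configs, the change is a single rename: exp_result_dir is replaced by exp_result_path, which points directly at the experiment's result folder instead of being joined with exp_name. A minimal sketch of the old versus new path construction, assuming the values shown in these configs:

import os

# Old scheme: results lived under exp_result_dir / exp_name.
exp_result_dir, exp_name = "results", "L3M"
old_path = os.path.join(exp_result_dir, exp_name)  # "results/L3M"

# New scheme: the full result path is given directly in the config.
exp_result_path = "results/L3M"

assert old_path == exp_result_path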
2 changes: 1 addition & 1 deletion examples/programmatic/analyze.py
@@ -41,7 +41,7 @@

exp_name = "L3M",
exp_series = "default",
exp_result_dir = "results",
exp_result_path = "results/L3M",
))

sample_feature_activations_runner(cfg)
2 changes: 1 addition & 1 deletion examples/programmatic/train.py
@@ -69,7 +69,7 @@

exp_name = f"test", # The experiment name. Would be used for creating exp folder (which may contain checkpoints and analysis results) and setting wandb run name.
exp_series = "test",
exp_result_dir = "results"
exp_result_path = "results/test"
))

sparse_autoencoder = language_model_sae_runner(cfg)
4 changes: 3 additions & 1 deletion server/app.py
@@ -46,7 +46,9 @@


def get_model(dictionary_name: str) -> HookedTransformer:
path = client.get_dictionary(dictionary_name, dictionary_series=dictionary_series)['path'] or f"{result_dir}/{dictionary_name}"
path = client.get_dictionary_path(dictionary_name, dictionary_series=dictionary_series)
    if path == "":
path = f"{result_dir}/{dictionary_name}"
cfg = LanguageModelConfig.from_pretrained_sae(path)
if (cfg.model_name, cfg.model_from_pretrained_path) not in lm_cache:
hf_model = AutoModelForCausalLM.from_pretrained(
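The dictionary lookup now calls client.get_dictionary_path instead of indexing the full dictionary record, and falls back to a path under result_dir when nothing is registered. A minimal sketch of that fallback as a standalone helper, with resolve_dictionary_path as a hypothetical name and assuming, as the code above does, that get_dictionary_path returns an empty string when no path is stored:

def resolve_dictionary_path(client, dictionary_name: str, dictionary_series: str, result_dir: str) -> str:
    # Assumes the client returns "" for dictionaries without a registered path.
    path = client.get_dictionary_path(dictionary_name, dictionary_series=dictionary_series)
    if path == "":
        path = f"{result_dir}/{dictionary_name}"
    return path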
163 changes: 124 additions & 39 deletions src/lm_saes/analysis/sample_feature_activations.py
@@ -1,6 +1,7 @@
import os
from typing import cast

from torch.distributed._tensor import DTensor
from tqdm import tqdm

import torch
@@ -16,6 +17,8 @@
from lm_saes.activation.activation_store import ActivationStore
from lm_saes.utils.misc import print_once
from lm_saes.utils.tensor_dict import concat_dict_of_tensor, sort_dict_of_tensor
import torch.distributed as dist


@torch.no_grad()
def sample_feature_activations(
@@ -28,10 +31,14 @@ def sample_feature_activations(
):
if sae.cfg.ddp_size > 1:
raise ValueError("Sampling feature activations does not support DDP yet")
assert cfg.sae.d_sae is not None # Make mypy happy
assert cfg.sae.d_sae is not None # Make mypy happy

total_analyzing_tokens = cfg.total_analyzing_tokens
total_analyzing_steps = total_analyzing_tokens // cfg.act_store.dataset.store_batch_size // cfg.act_store.dataset.context_size
total_analyzing_steps = (
total_analyzing_tokens
// cfg.act_store.dataset.store_batch_size
// cfg.act_store.dataset.context_size
)

print_once(f"Total Analyzing Tokens: {total_analyzing_tokens}")
print_once(f"Total Analyzing Steps: {total_analyzing_steps}")
@@ -41,19 +48,43 @@ def sample_feature_activations(

sae.eval()

pbar = tqdm(total=total_analyzing_tokens, desc=f"Sampling activations of chunk {sae_chunk_id} of {n_sae_chunks}", smoothing=0.01)
pbar = tqdm(
total=total_analyzing_tokens,
desc=f"Sampling activations of chunk {sae_chunk_id} of {n_sae_chunks}",
smoothing=0.01,
)

d_sae = cfg.sae.d_sae // n_sae_chunks
start_index = sae_chunk_id * d_sae
end_index = (sae_chunk_id + 1) * d_sae

sample_result = {k: {
"elt": torch.empty((0, d_sae), dtype=cfg.sae.dtype, device=cfg.sae.device),
"feature_acts": torch.empty((0, d_sae, cfg.act_store.dataset.context_size), dtype=cfg.sae.dtype, device=cfg.sae.device),
"contexts": torch.empty((0, d_sae, cfg.act_store.dataset.context_size), dtype=torch.int32, device=cfg.sae.device),
} for k in cfg.subsample.keys()}
assert (
d_sae // cfg.sae.tp_size * cfg.sae.tp_size == d_sae
), "d_sae must be divisible by tp_size"
d_sae //= cfg.sae.tp_size

rank = dist.get_rank() if cfg.sae.tp_size > 1 else 0
start_index = sae_chunk_id * d_sae * cfg.sae.tp_size + d_sae * rank
end_index = sae_chunk_id * d_sae * cfg.sae.tp_size + d_sae * (rank + 1)

sample_result = {
k: {
"elt": torch.empty((0, d_sae), dtype=cfg.sae.dtype, device=cfg.sae.device),
"feature_acts": torch.empty(
(0, d_sae, cfg.act_store.dataset.context_size),
dtype=cfg.sae.dtype,
device=cfg.sae.device,
),
"contexts": torch.empty(
(0, d_sae, cfg.act_store.dataset.context_size),
dtype=torch.int32,
device=cfg.sae.device,
),
}
for k in cfg.subsample.keys()
}
act_times = torch.zeros((d_sae,), dtype=torch.long, device=cfg.sae.device)
feature_acts_all = [torch.empty((0,), dtype=cfg.sae.dtype, device=cfg.sae.device) for _ in range(d_sae)]
feature_acts_all = [
torch.empty((0,), dtype=cfg.sae.dtype, device=cfg.sae.device)
for _ in range(d_sae)
]
max_feature_acts = torch.zeros((d_sae,), dtype=cfg.sae.dtype, device=cfg.sae.device)

while n_training_tokens < total_analyzing_tokens:
@@ -62,76 +93,129 @@ def sample_feature_activations(
if batch is None:
raise ValueError("Not enough tokens to sample")

_, cache = model.run_with_cache_until(batch, names_filter=[cfg.sae.hook_point_in, cfg.sae.hook_point_out], until=cfg.sae.hook_point_out)
activation_in, activation_out = cache[cfg.sae.hook_point_in], cache[cfg.sae.hook_point_out]
_, cache = model.run_with_cache_until(
batch,
names_filter=[cfg.sae.hook_point_in, cfg.sae.hook_point_out],
until=cfg.sae.hook_point_out,
)
activation_in, activation_out = (
cache[cfg.sae.hook_point_in],
cache[cfg.sae.hook_point_out],
)

filter_mask = torch.logical_or(
batch.eq(model.tokenizer.eos_token_id),
batch.eq(model.tokenizer.pad_token_id)
batch.eq(model.tokenizer.pad_token_id),
)
filter_mask = torch.logical_or(
filter_mask,
batch.eq(model.tokenizer.bos_token_id)
filter_mask, batch.eq(model.tokenizer.bos_token_id)
)

feature_acts = sae.encode(activation_in, label=activation_out)[..., start_index: end_index]
feature_acts = sae.encode(activation_in, label=activation_out)[
..., start_index:end_index
]
if isinstance(feature_acts, DTensor):
feature_acts = feature_acts.to_local()

feature_acts[filter_mask] = 0

act_times += feature_acts.gt(0.0).sum(dim=[0, 1])

for name in cfg.subsample.keys():

if cfg.enable_sampling:
weights = feature_acts.clamp(min=0.0).pow(cfg.sample_weight_exponent).max(dim=1).values
elt = torch.rand(batch.size(0), d_sae, device=cfg.sae.device, dtype=cfg.sae.dtype).log() / weights
weights = (
feature_acts.clamp(min=0.0)
.pow(cfg.sample_weight_exponent)
.max(dim=1)
.values
)
elt = (
torch.rand(
batch.size(0), d_sae, device=cfg.sae.device, dtype=cfg.sae.dtype
).log()
/ weights
)
elt[weights == 0.0] = -torch.inf
else:
elt = feature_acts.clamp(min=0.0).max(dim=1).values

elt[feature_acts.max(dim=1).values > max_feature_acts.unsqueeze(0) * cfg.subsample[name]["proportion"]] = -torch.inf
elt[
feature_acts.max(dim=1).values
> max_feature_acts.unsqueeze(0) * cfg.subsample[name]["proportion"]
] = -torch.inf

if sample_result[name]["elt"].size(0) > 0 and (elt.max(dim=0).values <= sample_result[name]["elt"][-1]).all():
if (
sample_result[name]["elt"].size(0) > 0
and (elt.max(dim=0).values <= sample_result[name]["elt"][-1]).all()
):
continue

sample_result[name] = concat_dict_of_tensor(
sample_result[name],
{
"elt": elt,
"feature_acts": rearrange(feature_acts, 'batch_size context_size d_sae -> batch_size d_sae context_size'),
"contexts": repeat(batch.to(torch.int32), 'batch_size context_size -> batch_size d_sae context_size', d_sae=d_sae),
"feature_acts": rearrange(
feature_acts,
"batch_size context_size d_sae -> batch_size d_sae context_size",
),
"contexts": repeat(
batch.to(torch.int32),
"batch_size context_size -> batch_size d_sae context_size",
d_sae=d_sae,
),
},
dim=0,
)

sample_result[name] = sort_dict_of_tensor(sample_result[name], sort_dim=0, sort_key="elt", descending=True)
sample_result[name] = sort_dict_of_tensor(
sample_result[name], sort_dim=0, sort_key="elt", descending=True
)
sample_result[name] = {
k: v[:cfg.subsample[name]["n_samples"]] for k, v in sample_result[name].items()
k: v[: cfg.subsample[name]["n_samples"]]
for k, v in sample_result[name].items()
}


            # Update feature activation histogram every 50 steps
if n_training_steps % 50 == 49:
feature_acts_cur = rearrange(feature_acts, 'batch_size context_size d_sae -> d_sae (batch_size context_size)')
feature_acts_cur = rearrange(
feature_acts,
"batch_size context_size d_sae -> d_sae (batch_size context_size)",
)
for i in range(d_sae):
feature_acts_all[i] = torch.cat([feature_acts_all[i], feature_acts_cur[i][feature_acts_cur[i] > 0.0]], dim=0)

max_feature_acts = torch.max(max_feature_acts, feature_acts.max(dim=0).values.max(dim=0).values)
feature_acts_all[i] = torch.cat(
[
feature_acts_all[i],
feature_acts_cur[i][feature_acts_cur[i] > 0.0],
],
dim=0,
)

max_feature_acts = torch.max(
max_feature_acts, feature_acts.max(dim=0).values.max(dim=0).values
)

n_tokens_current = torch.tensor(batch.size(0) * batch.size(1), device=cfg.sae.device, dtype=torch.int)
n_tokens_current = torch.tensor(
batch.size(0) * batch.size(1), device=cfg.sae.device, dtype=torch.int
)
n_training_tokens += cast(int, n_tokens_current.item())
n_training_steps += 1

pbar.update(n_tokens_current.item())

pbar.close()

sample_result = {k1: {
k2: rearrange(v2, 'n_samples d_sae ... -> d_sae n_samples ...') for k2, v2 in v1.items()
} for k1, v1 in sample_result.items()}
sample_result = {
k1: {
k2: rearrange(v2, "n_samples d_sae ... -> d_sae n_samples ...")
for k2, v2 in v1.items()
}
for k1, v1 in sample_result.items()
}

result = {
"index": torch.arange(start_index, end_index, device=cfg.sae.device, dtype=torch.int32),
"index": torch.arange(
start_index, end_index, device=cfg.sae.device, dtype=torch.int32
),
"act_times": act_times,
"feature_acts_all": feature_acts_all,
"max_feature_acts": max_feature_acts,
@@ -140,8 +224,9 @@ def sample_feature_activations(
"name": k,
"feature_acts": v["feature_acts"],
"contexts": v["contexts"],
} for k, v in sample_result.items()
}
for k, v in sample_result.items()
],
}

return result
return result
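The core of the tensor-parallel support is the index arithmetic above: d_sae is first split into n_sae_chunks chunks, each chunk is split again across tp_size ranks, and every rank samples only its own contiguous feature slice (with DTensor activations converted to local shards via to_local). A minimal sketch of the per-rank slice computation, using hypothetical standalone names for the quantities involved:

def feature_slice(d_sae_total: int, n_sae_chunks: int, tp_size: int,
                  sae_chunk_id: int, rank: int) -> tuple[int, int]:
    # Split features into chunks, then split each chunk across tensor-parallel ranks.
    d_sae_chunk = d_sae_total // n_sae_chunks
    assert d_sae_chunk % tp_size == 0, "d_sae must be divisible by tp_size"
    d_sae_local = d_sae_chunk // tp_size
    start_index = sae_chunk_id * d_sae_chunk + rank * d_sae_local
    end_index = start_index + d_sae_local
    return start_index, end_index

# Example: 24576 features, 2 chunks, tp_size 4 -> chunk 1, rank 2 owns features [18432, 21504).
print(feature_slice(24576, 2, 4, 1, 2))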
27 changes: 13 additions & 14 deletions src/lm_saes/config.py
@@ -53,13 +53,12 @@ def __post_init__(self):
class RunnerConfig(BaseConfig):
exp_name: str = "test"
exp_series: Optional[str] = None
exp_result_dir: str = "results"
exp_result_path: str = "results"

def __post_init__(self):
super().__post_init__()
if is_master():
os.makedirs(self.exp_result_dir, exist_ok=True)
os.makedirs(os.path.join(self.exp_result_dir, self.exp_name), exist_ok=True)
os.makedirs(self.exp_result_path, exist_ok=True)


@dataclass(kw_only=True)
@@ -257,12 +256,12 @@ def from_pretrained(
@deprecated("Use from_pretrained and to_dict instead.")
@staticmethod
def get_hyperparameters(
exp_name: str, exp_result_dir: str, ckpt_name: str, strict_loading: bool = True
exp_result_path: str, ckpt_name: str, strict_loading: bool = True
) -> dict[str, Any]:
with open(os.path.join(exp_result_dir, exp_name, "hyperparams.json"), "r") as f:
with open(os.path.join(exp_result_path, "hyperparams.json"), "r") as f:
hyperparams = json.load(f)
hyperparams["sae_pretrained_name_or_path"] = os.path.join(
exp_result_dir, exp_name, "checkpoints", ckpt_name
exp_result_path, "checkpoints", ckpt_name
)
hyperparams["strict_loading"] = strict_loading
# Remove non-hyperparameters from the dict
@@ -350,13 +349,13 @@ def __post_init__(self):
super().__post_init__()

if is_master():
# if os.path.exists(
# os.path.join(self.exp_result_dir, self.exp_name, "checkpoints")
# ):
# raise ValueError(
# f"Checkpoints for experiment {self.exp_name} already exist. Consider changing the experiment name."
# )
os.makedirs(os.path.join(self.exp_result_dir, self.exp_name, "checkpoints"), exist_ok=True)
if os.path.exists(
os.path.join(self.exp_result_path, "checkpoints")
):
raise ValueError(
f"Checkpoints for experiment {self.exp_result_path} already exist. Consider changing the experiment name."
)
os.makedirs(os.path.join(self.exp_result_path, "checkpoints"))

self.effective_batch_size = self.train_batch_size * self.sae.ddp_size
print_once(f"Effective batch size: {self.effective_batch_size}")
@@ -418,7 +417,7 @@ def __post_init__(self):

if is_master():
os.makedirs(
os.path.join(self.exp_result_dir, self.exp_name, "checkpoints"),
os.path.join(self.exp_result_path, "checkpoints"),
exist_ok=True,
)

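On the config side, RunnerConfig now creates exp_result_path directly, and the previously commented-out checkpoint guard in the training config is re-enabled: a run now fails fast if {exp_result_path}/checkpoints already exists. A minimal sketch of that guard in isolation, with prepare_result_dir as a hypothetical wrapper name:

import os

def prepare_result_dir(exp_result_path: str) -> None:
    # Mirrors the training-config __post_init__ above: refuse to overwrite an existing run.
    checkpoint_dir = os.path.join(exp_result_path, "checkpoints")
    if os.path.exists(checkpoint_dir):
        raise ValueError(
            f"Checkpoints for experiment {exp_result_path} already exist. "
            "Consider changing the experiment name."
        )
    os.makedirs(checkpoint_dir)

prepare_result_dir("results/L3M")  # first run creates results/L3M/checkpoints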