Skip to content

Commit

Permalink
Merge pull request #51 from OpenMOSS/dev
Browse files Browse the repository at this point in the history
support tensor parallel analysis
  • Loading branch information
Hzfinfdu authored Aug 24, 2024
2 parents 05e13f6 + 5f578f9 commit 2d23c75
Show file tree
Hide file tree
Showing 11 changed files with 173 additions and 78 deletions.
2 changes: 1 addition & 1 deletion examples/configuration/analyze.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ dtype = "torch.float32"

exp_name = "L3M"
exp_series = "default"
exp_result_dir = "results"
exp_result_path = "results/L3M"

[subsample]
"top_activations" = { "proportion" = 1.0, "n_samples" = 80 }
Expand Down
2 changes: 1 addition & 1 deletion examples/configuration/prune.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ dtype = "torch.float32"

exp_name = "L3M"
exp_series = "default"
exp_result_dir = "results"
exp_result_path = "results/L3M"

total_training_tokens = 10_000_000
train_batch_size = 4096
Expand Down
2 changes: 1 addition & 1 deletion examples/configuration/train.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use_ddp = false
exp_name = "L3M"
exp_result_dir = "results"
exp_result_path = "results/L3M"
device = "cuda"
seed = 42
dtype = "torch.float32"
Expand Down
2 changes: 1 addition & 1 deletion examples/programmatic/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@

exp_name = "L3M",
exp_series = "default",
exp_result_dir = "results",
exp_result_path = "results/L3M",
))

sample_feature_activations_runner(cfg)
2 changes: 1 addition & 1 deletion examples/programmatic/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@

exp_name = f"test", # The experiment name. Would be used for creating exp folder (which may contain checkpoints and analysis results) and setting wandb run name.
exp_series = "test",
exp_result_dir = "results"
exp_result_path = "results/test"
))

sparse_autoencoder = language_model_sae_runner(cfg)
4 changes: 3 additions & 1 deletion server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@


def get_model(dictionary_name: str) -> HookedTransformer:
path = client.get_dictionary(dictionary_name, dictionary_series=dictionary_series)['path'] or f"{result_dir}/{dictionary_name}"
path = client.get_dictionary_path(dictionary_name, dictionary_series=dictionary_series)
if path is "":
path = f"{result_dir}/{dictionary_name}"
cfg = LanguageModelConfig.from_pretrained_sae(path)
if (cfg.model_name, cfg.model_from_pretrained_path) not in lm_cache:
hf_model = AutoModelForCausalLM.from_pretrained(
Expand Down
163 changes: 124 additions & 39 deletions src/lm_saes/analysis/sample_feature_activations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from typing import cast

from torch.distributed._tensor import DTensor
from tqdm import tqdm

import torch
Expand All @@ -16,6 +17,8 @@
from lm_saes.activation.activation_store import ActivationStore
from lm_saes.utils.misc import print_once
from lm_saes.utils.tensor_dict import concat_dict_of_tensor, sort_dict_of_tensor
import torch.distributed as dist


@torch.no_grad()
def sample_feature_activations(
Expand All @@ -28,10 +31,14 @@ def sample_feature_activations(
):
if sae.cfg.ddp_size > 1:
raise ValueError("Sampling feature activations does not support DDP yet")
assert cfg.sae.d_sae is not None # Make mypy happy
assert cfg.sae.d_sae is not None # Make mypy happy

total_analyzing_tokens = cfg.total_analyzing_tokens
total_analyzing_steps = total_analyzing_tokens // cfg.act_store.dataset.store_batch_size // cfg.act_store.dataset.context_size
total_analyzing_steps = (
total_analyzing_tokens
// cfg.act_store.dataset.store_batch_size
// cfg.act_store.dataset.context_size
)

print_once(f"Total Analyzing Tokens: {total_analyzing_tokens}")
print_once(f"Total Analyzing Steps: {total_analyzing_steps}")
Expand All @@ -41,19 +48,43 @@ def sample_feature_activations(

sae.eval()

pbar = tqdm(total=total_analyzing_tokens, desc=f"Sampling activations of chunk {sae_chunk_id} of {n_sae_chunks}", smoothing=0.01)
pbar = tqdm(
total=total_analyzing_tokens,
desc=f"Sampling activations of chunk {sae_chunk_id} of {n_sae_chunks}",
smoothing=0.01,
)

d_sae = cfg.sae.d_sae // n_sae_chunks
start_index = sae_chunk_id * d_sae
end_index = (sae_chunk_id + 1) * d_sae

sample_result = {k: {
"elt": torch.empty((0, d_sae), dtype=cfg.sae.dtype, device=cfg.sae.device),
"feature_acts": torch.empty((0, d_sae, cfg.act_store.dataset.context_size), dtype=cfg.sae.dtype, device=cfg.sae.device),
"contexts": torch.empty((0, d_sae, cfg.act_store.dataset.context_size), dtype=torch.int32, device=cfg.sae.device),
} for k in cfg.subsample.keys()}
assert (
d_sae // cfg.sae.tp_size * cfg.sae.tp_size == d_sae
), "d_sae must be divisible by tp_size"
d_sae //= cfg.sae.tp_size

rank = dist.get_rank() if cfg.sae.tp_size > 1 else 0
start_index = sae_chunk_id * d_sae * cfg.sae.tp_size + d_sae * rank
end_index = sae_chunk_id * d_sae * cfg.sae.tp_size + d_sae * (rank + 1)

sample_result = {
k: {
"elt": torch.empty((0, d_sae), dtype=cfg.sae.dtype, device=cfg.sae.device),
"feature_acts": torch.empty(
(0, d_sae, cfg.act_store.dataset.context_size),
dtype=cfg.sae.dtype,
device=cfg.sae.device,
),
"contexts": torch.empty(
(0, d_sae, cfg.act_store.dataset.context_size),
dtype=torch.int32,
device=cfg.sae.device,
),
}
for k in cfg.subsample.keys()
}
act_times = torch.zeros((d_sae,), dtype=torch.long, device=cfg.sae.device)
feature_acts_all = [torch.empty((0,), dtype=cfg.sae.dtype, device=cfg.sae.device) for _ in range(d_sae)]
feature_acts_all = [
torch.empty((0,), dtype=cfg.sae.dtype, device=cfg.sae.device)
for _ in range(d_sae)
]
max_feature_acts = torch.zeros((d_sae,), dtype=cfg.sae.dtype, device=cfg.sae.device)

while n_training_tokens < total_analyzing_tokens:
Expand All @@ -62,76 +93,129 @@ def sample_feature_activations(
if batch is None:
raise ValueError("Not enough tokens to sample")

_, cache = model.run_with_cache_until(batch, names_filter=[cfg.sae.hook_point_in, cfg.sae.hook_point_out], until=cfg.sae.hook_point_out)
activation_in, activation_out = cache[cfg.sae.hook_point_in], cache[cfg.sae.hook_point_out]
_, cache = model.run_with_cache_until(
batch,
names_filter=[cfg.sae.hook_point_in, cfg.sae.hook_point_out],
until=cfg.sae.hook_point_out,
)
activation_in, activation_out = (
cache[cfg.sae.hook_point_in],
cache[cfg.sae.hook_point_out],
)

filter_mask = torch.logical_or(
batch.eq(model.tokenizer.eos_token_id),
batch.eq(model.tokenizer.pad_token_id)
batch.eq(model.tokenizer.pad_token_id),
)
filter_mask = torch.logical_or(
filter_mask,
batch.eq(model.tokenizer.bos_token_id)
filter_mask, batch.eq(model.tokenizer.bos_token_id)
)

feature_acts = sae.encode(activation_in, label=activation_out)[..., start_index: end_index]
feature_acts = sae.encode(activation_in, label=activation_out)[
..., start_index:end_index
]
if isinstance(feature_acts, DTensor):
feature_acts = feature_acts.to_local()

feature_acts[filter_mask] = 0

act_times += feature_acts.gt(0.0).sum(dim=[0, 1])

for name in cfg.subsample.keys():

if cfg.enable_sampling:
weights = feature_acts.clamp(min=0.0).pow(cfg.sample_weight_exponent).max(dim=1).values
elt = torch.rand(batch.size(0), d_sae, device=cfg.sae.device, dtype=cfg.sae.dtype).log() / weights
weights = (
feature_acts.clamp(min=0.0)
.pow(cfg.sample_weight_exponent)
.max(dim=1)
.values
)
elt = (
torch.rand(
batch.size(0), d_sae, device=cfg.sae.device, dtype=cfg.sae.dtype
).log()
/ weights
)
elt[weights == 0.0] = -torch.inf
else:
elt = feature_acts.clamp(min=0.0).max(dim=1).values

elt[feature_acts.max(dim=1).values > max_feature_acts.unsqueeze(0) * cfg.subsample[name]["proportion"]] = -torch.inf
elt[
feature_acts.max(dim=1).values
> max_feature_acts.unsqueeze(0) * cfg.subsample[name]["proportion"]
] = -torch.inf

if sample_result[name]["elt"].size(0) > 0 and (elt.max(dim=0).values <= sample_result[name]["elt"][-1]).all():
if (
sample_result[name]["elt"].size(0) > 0
and (elt.max(dim=0).values <= sample_result[name]["elt"][-1]).all()
):
continue

sample_result[name] = concat_dict_of_tensor(
sample_result[name],
{
"elt": elt,
"feature_acts": rearrange(feature_acts, 'batch_size context_size d_sae -> batch_size d_sae context_size'),
"contexts": repeat(batch.to(torch.int32), 'batch_size context_size -> batch_size d_sae context_size', d_sae=d_sae),
"feature_acts": rearrange(
feature_acts,
"batch_size context_size d_sae -> batch_size d_sae context_size",
),
"contexts": repeat(
batch.to(torch.int32),
"batch_size context_size -> batch_size d_sae context_size",
d_sae=d_sae,
),
},
dim=0,
)

sample_result[name] = sort_dict_of_tensor(sample_result[name], sort_dim=0, sort_key="elt", descending=True)
sample_result[name] = sort_dict_of_tensor(
sample_result[name], sort_dim=0, sort_key="elt", descending=True
)
sample_result[name] = {
k: v[:cfg.subsample[name]["n_samples"]] for k, v in sample_result[name].items()
k: v[: cfg.subsample[name]["n_samples"]]
for k, v in sample_result[name].items()
}


# Update feature activation histogram every 10 steps
if n_training_steps % 50 == 49:
feature_acts_cur = rearrange(feature_acts, 'batch_size context_size d_sae -> d_sae (batch_size context_size)')
feature_acts_cur = rearrange(
feature_acts,
"batch_size context_size d_sae -> d_sae (batch_size context_size)",
)
for i in range(d_sae):
feature_acts_all[i] = torch.cat([feature_acts_all[i], feature_acts_cur[i][feature_acts_cur[i] > 0.0]], dim=0)

max_feature_acts = torch.max(max_feature_acts, feature_acts.max(dim=0).values.max(dim=0).values)
feature_acts_all[i] = torch.cat(
[
feature_acts_all[i],
feature_acts_cur[i][feature_acts_cur[i] > 0.0],
],
dim=0,
)

max_feature_acts = torch.max(
max_feature_acts, feature_acts.max(dim=0).values.max(dim=0).values
)

n_tokens_current = torch.tensor(batch.size(0) * batch.size(1), device=cfg.sae.device, dtype=torch.int)
n_tokens_current = torch.tensor(
batch.size(0) * batch.size(1), device=cfg.sae.device, dtype=torch.int
)
n_training_tokens += cast(int, n_tokens_current.item())
n_training_steps += 1

pbar.update(n_tokens_current.item())

pbar.close()

sample_result = {k1: {
k2: rearrange(v2, 'n_samples d_sae ... -> d_sae n_samples ...') for k2, v2 in v1.items()
} for k1, v1 in sample_result.items()}
sample_result = {
k1: {
k2: rearrange(v2, "n_samples d_sae ... -> d_sae n_samples ...")
for k2, v2 in v1.items()
}
for k1, v1 in sample_result.items()
}

result = {
"index": torch.arange(start_index, end_index, device=cfg.sae.device, dtype=torch.int32),
"index": torch.arange(
start_index, end_index, device=cfg.sae.device, dtype=torch.int32
),
"act_times": act_times,
"feature_acts_all": feature_acts_all,
"max_feature_acts": max_feature_acts,
Expand All @@ -140,8 +224,9 @@ def sample_feature_activations(
"name": k,
"feature_acts": v["feature_acts"],
"contexts": v["contexts"],
} for k, v in sample_result.items()
}
for k, v in sample_result.items()
],
}

return result
return result
27 changes: 13 additions & 14 deletions src/lm_saes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,12 @@ def __post_init__(self):
class RunnerConfig(BaseConfig):
exp_name: str = "test"
exp_series: Optional[str] = None
exp_result_dir: str = "results"
exp_result_path: str = "results"

def __post_init__(self):
super().__post_init__()
if is_master():
os.makedirs(self.exp_result_dir, exist_ok=True)
os.makedirs(os.path.join(self.exp_result_dir, self.exp_name), exist_ok=True)
os.makedirs(self.exp_result_path, exist_ok=True)


@dataclass(kw_only=True)
Expand Down Expand Up @@ -257,12 +256,12 @@ def from_pretrained(
@deprecated("Use from_pretrained and to_dict instead.")
@staticmethod
def get_hyperparameters(
exp_name: str, exp_result_dir: str, ckpt_name: str, strict_loading: bool = True
exp_result_path: str, ckpt_name: str, strict_loading: bool = True
) -> dict[str, Any]:
with open(os.path.join(exp_result_dir, exp_name, "hyperparams.json"), "r") as f:
with open(os.path.join(exp_result_path, "hyperparams.json"), "r") as f:
hyperparams = json.load(f)
hyperparams["sae_pretrained_name_or_path"] = os.path.join(
exp_result_dir, exp_name, "checkpoints", ckpt_name
exp_result_path, "checkpoints", ckpt_name
)
hyperparams["strict_loading"] = strict_loading
# Remove non-hyperparameters from the dict
Expand Down Expand Up @@ -350,13 +349,13 @@ def __post_init__(self):
super().__post_init__()

if is_master():
# if os.path.exists(
# os.path.join(self.exp_result_dir, self.exp_name, "checkpoints")
# ):
# raise ValueError(
# f"Checkpoints for experiment {self.exp_name} already exist. Consider changing the experiment name."
# )
os.makedirs(os.path.join(self.exp_result_dir, self.exp_name, "checkpoints"), exist_ok=True)
if os.path.exists(
os.path.join(self.exp_result_path, "checkpoints")
):
raise ValueError(
f"Checkpoints for experiment {self.exp_result_path} already exist. Consider changing the experiment name."
)
os.makedirs(os.path.join(self.exp_result_path, "checkpoints"))

self.effective_batch_size = self.train_batch_size * self.sae.ddp_size
print_once(f"Effective batch size: {self.effective_batch_size}")
Expand Down Expand Up @@ -418,7 +417,7 @@ def __post_init__(self):

if is_master():
os.makedirs(
os.path.join(self.exp_result_dir, self.exp_name, "checkpoints"),
os.path.join(self.exp_result_path, "checkpoints"),
exist_ok=True,
)

Expand Down
Loading

0 comments on commit 2d23c75

Please sign in to comment.