Commit 0c0cc2d
Merge branch 'main' into dev
Frankstein73 committed Jul 31, 2024
2 parents b067b07 + 173f434 commit 0c0cc2d
Showing 20 changed files with 327 additions and 96 deletions.
@@ -209,7 +209,7 @@ def forward(
self.apply_rotary(k, 0, attention_mask)
) # keys are cached so no offset

- if self.cfg.dtype not in [torch.float32, torch.float64]:
+ if self.cfg.dtype not in [torch.float32, torch.float64, torch.bfloat16]:
# If using 16 bits, increase the precision to avoid numerical instabilities
q = q.to(torch.float32)
k = k.to(torch.float32)
2 changes: 1 addition & 1 deletion TransformerLens/transformer_lens/components/rms_norm.py
@@ -36,7 +36,7 @@ def __init__(self, cfg: Union[Dict, HookedTransformerConfig], length: Optional[i
def forward(
self, x: Float[torch.Tensor, "batch pos length"]
) -> Float[torch.Tensor, "batch pos length"]:
- if self.cfg.dtype not in [torch.float32, torch.float64]:
+ if self.cfg.dtype not in [torch.float32, torch.float64, torch.bfloat16]:
x = x.to(torch.float32)
scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(
(x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
@@ -26,7 +26,7 @@ def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
def forward(
self, x: Float[torch.Tensor, "batch pos length"]
) -> Float[torch.Tensor, "batch pos length"]:
- if self.cfg.dtype not in [torch.float32, torch.float64]:
+ if self.cfg.dtype not in [torch.float32, torch.float64, torch.bfloat16]:
x = x.to(torch.float32)

scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(
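Editor's note: the same one-line change appears in three TransformerLens components (the attention rotary path and the two RMS-norm variants): bfloat16 activations are no longer upcast to float32 before the sensitive math, since bfloat16 already has float32's exponent range. A standalone sketch of the gate's effect, not library code:

```python
import torch

# Minimal sketch of the dtype gate after this change: true 16-bit float16 is
# still upcast for numerical stability, while bfloat16 is left untouched.
def maybe_upcast(x: torch.Tensor) -> torch.Tensor:
    if x.dtype not in (torch.float32, torch.float64, torch.bfloat16):
        return x.to(torch.float32)
    return x

print(maybe_upcast(torch.ones(2, dtype=torch.float16)).dtype)   # torch.float32
print(maybe_upcast(torch.ones(2, dtype=torch.bfloat16)).dtype)  # torch.bfloat16
```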
21 changes: 21 additions & 0 deletions TransformerLens/transformer_lens/loading_from_pretrained.py
@@ -122,7 +122,9 @@
"CodeLlama-7b-Python-hf",
"CodeLlama-7b-Instruct-hf",
"meta-llama/Meta-Llama-3-8B",
"meta-llama/Meta-Llama-3.1-8B",
"meta-llama/Meta-Llama-3-8B-Instruct",
"meta-llama/Meta-Llama-3.1-8B-Instruct",
"meta-llama/Meta-Llama-3-70B",
"meta-llama/Meta-Llama-3-70B-Instruct",
"Baidicoot/Othello-GPT-Transformer-Lens",
@@ -809,6 +811,25 @@ def convert_hf_model_config(model_name: str, **kwargs):
"final_rms": True,
"gated_mlp": True,
}
elif "Meta-Llama-3.1-8B" in official_model_name:
cfg_dict = {
"d_model": 4096,
"d_head": 128,
"n_heads": 32,
"d_mlp": 14336,
"n_layers": 32,
"n_ctx": 8192,
"eps": 1e-5,
"d_vocab": 128256,
"act_fn": "silu",
"n_key_value_heads": 8,
"normalization_type": "RMS",
"positional_embedding_type": "rotary",
"rotary_adjacent_pairs": False,
"rotary_dim": 128,
"final_rms": True,
"gated_mlp": True,
}
elif "Meta-Llama-3-70B" in official_model_name:
cfg_dict = {
"d_model": 8192,
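Editor's note: with the 3.1 names and config registered, the new checkpoints load the same way as the existing Llama 3 entries. A minimal sketch, assuming the gated meta-llama weights are accessible (Hugging Face token or local copy) and the transformers>=4.43.0 bump from pyproject.toml below:

```python
from transformer_lens import HookedTransformer

# Sketch only; requires access to the gated meta-llama repository.
model = HookedTransformer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
logits = model("The capital of France is")
print(logits.shape)  # [batch, pos, d_vocab] -> last dim is 128256 per the config above
```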
10 changes: 5 additions & 5 deletions pdm.lock

Some generated files (pdm.lock) are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -12,6 +12,7 @@ authors = [
]
dependencies = [
"datasets>=2.17.0",
"transformers>=4.43.0",
"einops>=0.7.0",
"fastapi>=0.110.0",
"matplotlib>=3.8.3",
4 changes: 2 additions & 2 deletions src/lm_saes/activation/token_source.py
@@ -116,13 +116,13 @@ def _process_dataset(dataset_path: str, cfg: TextDatasetConfig):
if dist.is_initialized():
shard_id = dist.get_rank()
shard = dataset.shard(
- num_shards=dist.get_world_size(), index=shard_id
+ num_shards=dist.get_world_size(), index=shard_id, contiguous=True
)
else:
shard = dataset


- dataloader = DataLoader(shard, batch_size=cfg.store_batch_size)
+ dataloader = DataLoader(shard, batch_size=cfg.store_batch_size, num_workers=4, prefetch_factor=4, pin_memory=True)
return dataloader

@staticmethod
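Editor's note: the two changes here are independent. `contiguous=True` gives each rank a contiguous slice of the dataset instead of a strided one, and the extra DataLoader arguments overlap data loading with the forward passes that consume the batches. A small sketch of the same pattern outside the codebase (dataset contents and batch size are placeholders):

```python
from datasets import Dataset
from torch.utils.data import DataLoader

# Toy stand-in for the text dataset used by the token source.
dataset = Dataset.from_dict({"text": [f"document {i}" for i in range(100)]})

world_size, rank = 4, 1  # would come from torch.distributed in the real runner
shard = dataset.shard(num_shards=world_size, index=rank, contiguous=True)

# num_workers/prefetch_factor keep batches queued ahead of the consumer;
# pin_memory speeds up host-to-GPU copies. On spawn-based platforms the worker
# processes need the usual `if __name__ == "__main__":` guard.
dataloader = DataLoader(shard, batch_size=8, num_workers=4, prefetch_factor=4, pin_memory=True)
for batch in dataloader:
    print(len(batch["text"]))  # 8 documents from this rank's contiguous slice
    break
```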
13 changes: 13 additions & 0 deletions src/lm_saes/config.py
@@ -382,6 +382,19 @@ def __post_init__(self):
print_once(f"Learning rate cool down steps: {self.lr_cool_down_steps}")


+ if self.lr_scheduler_name == "constantwithwarmup" and isinstance(self.lr_warm_up_steps, float):
+     assert 0 <= self.lr_warm_up_steps <= 1.0
+     self.lr_warm_up_steps = int(self.lr_warm_up_steps * total_training_steps)
+     print_once(f"Learning rate warm up steps: {self.lr_warm_up_steps}")
+ if isinstance(self.sae.l1_coefficient_warmup_steps, float):
+     assert 0 <= self.sae.l1_coefficient_warmup_steps <= 1.0
+     self.sae.l1_coefficient_warmup_steps = int(self.sae.l1_coefficient_warmup_steps * total_training_steps)
+     print_once(f"L1 coefficient warm up steps: {self.sae.l1_coefficient_warmup_steps}")
+ if isinstance(self.lr_cool_down_steps, float):
+     assert 0 <= self.lr_cool_down_steps <= 1.0
+     self.lr_cool_down_steps = int(self.lr_cool_down_steps * total_training_steps)
+     print_once(f"Learning rate cool down steps: {self.lr_cool_down_steps}")

@dataclass(kw_only=True)
class LanguageModelSAEPruningConfig(LanguageModelSAERunnerConfig):
"""
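Editor's note: after this change, warm-up and cool-down lengths can be given either as absolute step counts or as fractions of the total number of training steps; fractions are resolved in `__post_init__`. A standalone sketch of the conversion rule (not the real config class):

```python
# Fraction-to-steps resolution mirroring the config change above; names are
# illustrative, the actual fields live on LanguageModelSAERunnerConfig.
def resolve_steps(value: float | int, total_training_steps: int) -> int:
    if isinstance(value, float):
        assert 0 <= value <= 1.0, "fractions must lie in [0, 1]"
        return int(value * total_training_steps)
    return value  # absolute counts pass through unchanged

total = 100_000
print(resolve_steps(0.05, total))   # 5000 warm-up steps
print(resolve_steps(2_000, total))  # 2000 warm-up steps
```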
63 changes: 27 additions & 36 deletions src/lm_saes/runner.py
@@ -32,6 +32,19 @@
from torch.nn.parallel import DistributedDataParallel as DDP
from lm_saes.utils.misc import is_master

+ from torch.distributed.tensor.parallel import (
+     ColwiseParallel,
+     parallelize_module,
+     loss_parallel,
+ )
+ from torch.distributed._tensor import (
+     DTensor,
+     Shard,
+     Replicate,
+     distribute_module,
+     distribute_tensor,
+ )


def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
if is_master():
@@ -77,43 +90,7 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
model.eval()
activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store)

- if (
-     cfg.sae.norm_activation == "dataset-wise" and cfg.sae.dataset_average_activation_norm is None
-     or cfg.sae.init_decoder_norm is None
- ):
-     assert not cfg.finetuning
-     sae = SparseAutoEncoder.from_initialization_searching(
-         activation_store=activation_store,
-         cfg=cfg,
-     )
- else:
-     sae = SparseAutoEncoder.from_config(cfg=cfg.sae)
-
- if cfg.finetuning:
-     # Fine-tune SAE with frozen encoder weights and bias
-     sae.train_finetune_for_suppression_parameters()
-
- cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
- cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))

- if (
-     cfg.sae.norm_activation == "dataset-wise" and cfg.sae.dataset_average_activation_norm is None
-     or cfg.sae.init_decoder_norm is None
- ):
-     assert not cfg.finetuning
-     sae = SparseAutoEncoder.from_initialization_searching(
-         activation_store=activation_store,
-         cfg=cfg,
-     )
- else:
-     sae = SparseAutoEncoder.from_config(cfg=cfg.sae)
-
- if cfg.finetuning:
-     # Fine-tune SAE with frozen encoder weights and bias
-     sae.train_finetune_for_suppression_parameters()
-
- cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
- cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))

if cfg.wandb.log_to_wandb and is_master():
wandb_config: dict = {
@@ -323,6 +300,20 @@ def activation_generation_runner(cfg: ActivationGenerationConfig):
def sample_feature_activations_runner(cfg: LanguageModelSAEAnalysisConfig):
sae = SparseAutoEncoder.from_config(cfg=cfg.sae)

+ if cfg.sae.tp_size > 1:
+     plan = {
+         "encoder": ColwiseParallel(output_layouts=Replicate()),
+     }
+     if cfg.sae.use_glu_encoder:
+         plan["encoder_glu"] = ColwiseParallel(output_layouts=Replicate())
+     sae = parallelize_module(sae, device_mesh=sae.device_mesh["tp"], parallelize_plan=plan)  # type: ignore
+     sae.parallelize_plan = plan
+
+ sae.decoder.weight = None  # type: ignore[assignment]
+ torch.cuda.empty_cache()



hf_model = AutoModelForCausalLM.from_pretrained(
(
cfg.lm.model_name
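Editor's note: `parallelize_module` rewrites the named submodules so their parameters become DTensors sharded across the "tp" mesh dimension; `ColwiseParallel` shards a Linear along its output features, and `output_layouts=Replicate()` gathers the result back to a full tensor on every rank. A toy single-module sketch of the same API (mesh size and module names are illustrative; it must be launched under torchrun with as many processes as the mesh has devices):

```python
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
from torch.distributed._tensor import Replicate

# Illustrative sketch, not the runner code. Run with e.g.
#   torchrun --nproc-per-node=2 this_file.py
mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))

model = nn.Sequential(nn.Linear(16, 64)).cuda()  # stand-in for the SAE encoder
plan = {"0": ColwiseParallel(output_layouts=Replicate())}
model = parallelize_module(model, device_mesh=mesh, parallelize_plan=plan)

x = torch.randn(4, 16, device="cuda")
y = model(x)      # each rank computes a shard of the 64 output features,
print(y.shape)    # then the Replicate() output layout gathers them to (4, 64)
```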
50 changes: 39 additions & 11 deletions src/lm_saes/sae.py
@@ -1,3 +1,4 @@
+ from builtins import print
from importlib.metadata import version
import os
from typing import Dict, Literal, Union, overload, List
@@ -60,9 +61,10 @@ def __init__(self, cfg: SAEConfig):
)
torch.nn.init.kaiming_uniform_(self.encoder.weight)
torch.nn.init.zeros_(self.encoder.bias)
- self.device_mesh = init_device_mesh(
-     "cuda", (cfg.ddp_size, cfg.tp_size), mesh_dim_names=("ddp", "tp")
- )
+ if cfg.tp_size > 1 or cfg.ddp_size > 1:
+     self.device_mesh = init_device_mesh(
+         "cuda", (cfg.ddp_size, cfg.tp_size), mesh_dim_names=("ddp", "tp")
+     )

if cfg.use_glu_encoder:

@@ -116,7 +118,7 @@ def initialize_parameters(self):
if self.cfg.init_encoder_with_decoder_transpose:
self.encoder.weight.data = self.decoder.weight.data.T.clone().contiguous()
else:
- self.set_encoder_norm_to_fixed_norm(self.cfg.init_encoder_norm)
+ self.set_encoder_norm_to_fixed_norm(self.cfg.init_encoder_norm, during_init=True)

def train_base_parameters(self):
"""Set the base parameters to be trained."""
@@ -480,6 +482,13 @@ def set_decoder_norm_to_fixed_norm(
decoder_norm = self.decoder_norm(keepdim=True, during_init=during_init)
if force_exact is None:
force_exact = self.cfg.decoder_exactly_fixed_norm


+ if self.cfg.tp_size > 1 and not during_init:
+     decoder_norm = distribute_tensor(
+         decoder_norm, device_mesh=self.device_mesh["tp"], placements=[Replicate()]
+     )

if force_exact:
self.decoder.weight.data = self.decoder.weight.data * value / decoder_norm
else:
@@ -489,15 +498,19 @@ def set_decoder_norm_to_fixed_norm(
)

@torch.no_grad()
- def set_encoder_norm_to_fixed_norm(self, value: float | None = 1.0):
+ def set_encoder_norm_to_fixed_norm(self, value: float | None = 1.0, during_init: bool = False):
if self.cfg.use_glu_encoder:
raise NotImplementedError("GLU encoder not supported")
if value is None:
print(
f"Encoder norm is not set to a fixed value, using random initialization."
)
return
- encoder_norm = self.encoder_norm(keepdim=True)
+ encoder_norm = self.encoder_norm(keepdim=True, during_init=during_init)
+ if self.cfg.tp_size > 1 and not during_init:
+     encoder_norm = distribute_tensor(
+         encoder_norm, device_mesh=self.device_mesh["tp"], placements=[Replicate()]
+     )
self.encoder.weight.data = self.encoder.weight.data * value / encoder_norm

@torch.no_grad()
@@ -514,10 +527,25 @@ def transform_to_unit_decoder_norm(self):
raise NotImplementedError("GLU encoder not supported")

decoder_norm = self.decoder_norm() # (d_sae,)
- self.encoder.weight.data = self.encoder.weight.data * decoder_norm[:, None]
- self.decoder.weight.data = self.decoder.weight.data / decoder_norm
+ if self.cfg.tp_size > 1:
+     decoder_norm_en = distribute_tensor(
+         decoder_norm[:, None], device_mesh=self.device_mesh["tp"], placements=[Replicate()]
+     )
+     decoder_norm_de = distribute_tensor(
+         decoder_norm, device_mesh=self.device_mesh["tp"], placements=[Replicate()]
+     )
+     dencoder_norm_bias = distribute_tensor(
+         decoder_norm, device_mesh=self.device_mesh["tp"], placements=[Replicate()]
+     )
+ else:
+     decoder_norm_en = decoder_norm[:, None]
+     decoder_norm_de = decoder_norm
+     dencoder_norm_bias = decoder_norm
+
+ self.encoder.weight.data = self.encoder.weight.data * decoder_norm_en
+ self.decoder.weight.data = self.decoder.weight.data / decoder_norm_de
+
- self.encoder.bias.data = self.encoder.bias.data * decoder_norm
+ self.encoder.bias.data = self.encoder.bias.data * dencoder_norm_bias

@torch.no_grad()
def remove_gradient_parallel_to_decoder_directions(self):
@@ -624,8 +652,8 @@ def from_initialization_searching(
cfg: LanguageModelSAETrainingConfig,
):
test_batch = activation_store.next(
-     batch_size=cfg.train_batch_size * 8
- )  # just random hard code xd
+     batch_size=cfg.train_batch_size
+ )
activation_in, activation_out = test_batch[cfg.sae.hook_point_in], test_batch[cfg.sae.hook_point_out] # type: ignore

if (
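Editor's note: the `distribute_tensor` calls above exist because, once the encoder/decoder weights have been turned into DTensors by `parallelize_module`, mixing them with plain local tensors in arithmetic fails; wrapping the norm as a replicated DTensor on the same mesh makes the scaling well-defined on every rank. A toy sketch of that conversion (illustrative mesh and shapes; launch under torchrun):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import Replicate, Shard, distribute_tensor

# Illustrative only; run with torchrun --nproc-per-node=2.
mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))

# A weight sharded across ranks along dim 0, standing in for the SAE decoder.
weight = distribute_tensor(torch.randn(8, 4, device="cuda"), device_mesh=mesh, placements=[Shard(0)])
norm = torch.ones(8, 1, device="cuda")  # a plain local tensor, e.g. a per-row norm

# Replicate the norm onto the same mesh so the elementwise op sees two DTensors
# and stays consistent across ranks.
norm_dt = distribute_tensor(norm, device_mesh=mesh, placements=[Replicate()])
scaled = weight / norm_dt
print(scaled.placements)  # result remains a DTensor on the "tp" mesh
```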
16 changes: 9 additions & 7 deletions src/lm_saes/sae_training.py
@@ -145,8 +145,9 @@ def train_sae(
if cfg.finetuning:
loss = loss_data["l_rec"].mean()
loss.backward()
+ grad_norm = torch.tensor([0.0], device=cfg.sae.device)
if cfg.clip_grad_norm > 0:
-     torch.nn.utils.clip_grad_norm_(sae.parameters(), cfg.clip_grad_norm)
+     grad_norm = torch.nn.utils.clip_grad_norm_(sae.parameters(), cfg.clip_grad_norm)
if cfg.remove_gradient_parallel_to_decoder_directions:
sae.remove_gradient_parallel_to_decoder_directions()
optimizer.step()
@@ -171,13 +172,13 @@ def train_sae(
if cfg.wandb.log_to_wandb and (is_master()):
feature_sparsity = act_freq_scores / n_frac_active_tokens
log_feature_sparsity = torch.log10(feature_sparsity + 1e-10)
- wandb_histogram = wandb.Histogram(
-     log_feature_sparsity.detach().cpu().float().numpy()
- )
+ # wandb_histogram = wandb.Histogram(
+ #     log_feature_sparsity.detach().cpu().float().numpy()
+ # )
wandb.log(
{
"metrics/mean_log10_feature_sparsity": log_feature_sparsity.mean().item(),
"plots/feature_density_line_chart": wandb_histogram,
# "plots/feature_density_line_chart": wandb_histogram,
"sparsity/below_1e-5": (feature_sparsity < 1e-5)
.sum()
.item(),
@@ -285,8 +286,9 @@ def train_sae(
# norm
"metrics/decoder_norm": decoder_norm.item(),
"metrics/encoder_norm": encoder_norm.item(),
"metrics/decoder_bias_mean": sae.decoder.bias.mean().item() if sae.cfg.use_decoder_bias else 0,
"metrics/enocder_bias_mean": sae.encoder.bias.mean().item(),
"metrics/decoder_bias_norm": sae.decoder.bias.norm().item() if sae.cfg.use_decoder_bias else 0,
"metrics/encoder_bias_norm": sae.encoder.bias.norm().item(),
"metrics/gradients_norm": grad_norm.item(),
# sparsity
"sparsity/l1_coefficient": sae.current_l1_coefficient,
"sparsity/mean_passes_since_fired": n_forward_passes_since_fired.mean().item(),
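Editor's note: `clip_grad_norm_` returns the total gradient norm it measured before clipping, so assigning its result is what feeds the new `metrics/gradients_norm` entry; the tensor initialised to 0.0 keeps the logging path valid when clipping is disabled. A self-contained illustration of the pattern (toy model, not the training loop):

```python
import torch

# Toy illustration; the linear layer stands in for the SAE.
model = torch.nn.Linear(4, 4)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()

clip_grad_norm = 1.0
grad_norm = torch.tensor([0.0])
if clip_grad_norm > 0:
    # Returns the total norm of all parameter gradients (pre-clipping) and
    # scales the gradients in place if that norm exceeds the threshold.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
print(grad_norm.item())
```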
12 changes: 12 additions & 0 deletions src/lm_saes/utils/convert_pre_enc_bias.py
@@ -0,0 +1,12 @@
+ from lm_saes.sae import SparseAutoEncoder
+ import torch
+
+
+ @torch.no_grad()
+ def merge_pre_enc_bias_to_enc_bias(sae: SparseAutoEncoder):
+     assert sae.cfg.apply_decoder_bias_to_pre_encoder
+
+     sae.cfg.apply_decoder_bias_to_pre_encoder = False
+     sae.encoder.bias.data = sae.encoder.bias.data - sae.encoder.weight.data @ sae.decoder.bias.data
+
+     return sae
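Editor's note: the identity behind this helper: with the decoder bias subtracted before encoding, the pre-activation is W_enc (x - b_dec) + b_enc = W_enc x + (b_enc - W_enc b_dec), so folding -W_enc b_dec into the encoder bias lets the flag be switched off without changing the SAE's outputs. A quick numerical check of that identity (toy shapes, independent of the lm_saes classes):

```python
import torch

# Toy check of the bias-folding identity; d_model=4, d_sae=6 are arbitrary.
d_model, d_sae = 4, 6
W_enc = torch.randn(d_sae, d_model)
b_enc = torch.randn(d_sae)
b_dec = torch.randn(d_model)
x = torch.randn(3, d_model)

pre_enc_bias = (x - b_dec) @ W_enc.T + b_enc      # decoder bias applied before encoding
folded = x @ W_enc.T + (b_enc - W_enc @ b_dec)    # equivalent after folding into b_enc
print(torch.allclose(pre_enc_bias, folded, atol=1e-6))  # True
```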