feature(MEGA UPDATE): implement options in Anthropic April Update e.g. combine decoder bias norm with L1
Hzfinfdu committed Jul 1, 2024
1 parent 6f11837 commit b8aba9e
Showing 5 changed files with 303 additions and 63 deletions.
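The headline option, `sparsity_include_decoder_norm` (the decoder-norm-with-L1 part of the commit message), follows Anthropic's April update in weighting each feature activation by the norm of its decoder direction inside the sparsity penalty, matching the comment in `config.py` below (sparsity loss = sum(act * corresponding_decoder_norm)). A minimal sketch of the two variants, using illustrative names and shapes rather than the repository's exact implementation:

import torch

def sparsity_loss(feature_acts: torch.Tensor,  # (batch, d_sae), non-negative feature activations
                  W_dec: torch.Tensor,         # (d_sae, d_model), decoder weight
                  include_decoder_norm: bool = True,
                  p: int = 1) -> torch.Tensor:
    # Illustrative L_p sparsity penalty, with or without decoder-norm weighting.
    if include_decoder_norm:
        # Weight each feature by the L2 norm of its decoder direction, so the
        # penalty cannot be gamed by shrinking decoder rows while inflating activations.
        decoder_norms = W_dec.norm(dim=1)  # (d_sae,)
        return (feature_acts.abs() ** p * decoder_norms).sum(dim=-1).mean()
    return (feature_acts.abs() ** p).sum(dim=-1).mean()

This is also why the config below marks `sparsity_include_decoder_norm` as incompatible with `decoder_exactly_fixed_norm`: with a fixed decoder norm, the weighting reduces to a constant rescaling of the plain L1 term.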
91 changes: 69 additions & 22 deletions examples/programmatic/train.py
@@ -2,61 +2,108 @@
from lm_saes.config import LanguageModelSAETrainingConfig
from lm_saes.runner import language_model_sae_runner

# import argparse


# parser = argparse.ArgumentParser()
# parser.add_argument("--lr", type=float, default=4e-4)
# parser.add_argument("--l1_coef", type=float, default=8e-5)
# parser.add_argument("--sparsity_include_decoder_norm", action="store_true")
# parser.add_argument("--remove_gradient_parallel_to_decoder_directions", action="store_true")
# parser.add_argument("--use_decoder_bias", action="store_true")
# parser.add_argument("--init_encoder_with_decoder_transpose", action="store_true")
# args = parser.parse_args()

# lr = args.lr
# l1_coefficient = args.l1_coef
# sparsity_include_decoder_norm = args.sparsity_include_decoder_norm
# remove_gradient_parallel_to_decoder_directions = args.remove_gradient_parallel_to_decoder_directions
# use_decoder_bias = args.use_decoder_bias
# init_encoder_with_decoder_transpose = args.init_encoder_with_decoder_transpose

cfg = LanguageModelSAETrainingConfig.from_flattened(dict(
# LanguageModelConfig
model_name = "gpt2", # The model name or path for the pre-trained model.
d_model = 768, # The hidden size of the model.
model_name = "meta-llama/Meta-Llama-3-8B-Instruct", # The model name or path for the pre-trained model.
model_from_pretrained_path="/remote-home/share/models/llama3_hf/Meta-Llama-3-8B-Instruct", # The path to load the pre-trained model.
d_model = 4096, # The hidden size of the model.

# TextDatasetConfig
dataset_path = "openwebtext", # The corpus name or path. Each of a data record should contain (and may only contain) a "text" field.
dataset_path = [
"/remote-home/share/personal/zfhe/projects/SAE_data_at_scale/data/Pretrain_RedPajama_arxiv_500k",
"/remote-home/share/personal/zfhe/projects/SAE_data_at_scale/data/Pretrain_RedPajama_c4_500k",
"/remote-home/share/personal/zfhe/projects/SAE_data_at_scale/data/Pretrain_RedPajama_stack_500k",
"/remote-home/share/personal/zfhe/projects/SAE_data_at_scale/data/Pretrain_RedPajama_book_500k",
"/remote-home/share/personal/zfhe/projects/SAE_data_at_scale/data/Pretrain_RedPajama_pile_500k",
"/remote-home/share/personal/zfhe/projects/SAE_data_at_scale/data/Pretrain_RedPajama_wiki_500k",
"/remote-home/share/personal/zfhe/projects/SAE_data_at_scale/data/SFT_WildChatClean",
], # The corpus name or path. Each data record should contain (and may only contain) a "text" field.
sample_probs = [1, 8, 1, 2, 4, 4, 1],
is_dataset_tokenized = False, # Whether the dataset is tokenized.
is_dataset_on_disk = True, # Whether the dataset is on disk. If not on disk, `datasets.load_dataset` will be used to load the dataset, and the train split will be used for training.
concat_tokens = False, # Whether to concatenate tokens into a single sequence. If False, only data records with more non-padding tokens than `context_size` will be used.
context_size = 256, # The sequence length of the text dataset.
store_batch_size = 32, # The batch size for loading the corpus.
concat_tokens = [
True,
True,
True,
True,
True,
True,
False,
], # Whether to concatenate tokens into a single sequence (one flag per dataset). If False, only data records with more non-padding tokens than `context_size` will be used.
context_size = 1024, # The sequence length of the text dataset.
store_batch_size = 20, # The batch size for loading the corpus.

# ActivationStoreConfig
hook_points = ["blocks.3.hook_mlp_out"], # Hook points to store activations from, i.e. the layer output of which is used for training/evaluating the dictionary. Will run until the last hook point in the list, so make sure to order them correctly.
hook_points = ['blocks.16.hook_resid_pre'], # Hook points to store activations from, i.e. the layer output of which is used for training/evaluating the dictionary. Will run until the last hook point in the list, so make sure to order them correctly.
use_cached_activations = False, # Whether to use cached activations. Caching activation is now not recommended, as it may consume extremely large disk space. (May be tens of TBs for corpus like `openwebtext`)
n_tokens_in_buffer = 500_000, # The number of tokens to store in the activation buffer. The buffer is used to shuffle the activations before training the dictionary.

# SAEConfig
hook_point_in = "blocks.3.hook_mlp_out",
hook_point_out = "blocks.3.hook_mlp_out",
hook_point_in = 'blocks.16.hook_resid_pre',
hook_point_out = 'blocks.16.hook_resid_pre',
use_decoder_bias = True, # Whether to use decoder bias.
expansion_factor = 32, # The expansion factor of the dictionary. d_sae = expansion_factor * d_model.
norm_activation = "token-wise", # The normalization method for the activations. Can be "token-wise", "batch-wise" or "none".
decoder_exactly_unit_norm = False, # Whether to enforce the decoder to have exactly unit norm. If False, the decoder will have less than or equal to unit norm.
decoder_exactly_fixed_norm = False, # Whether to enforce the decoder to have an exactly fixed norm. If False, the decoder norm will be less than or equal to the fixed value.
use_glu_encoder = False, # Whether to use the Gated Linear Unit (GLU) for the encoder.
l1_coefficient = 1.2e-4, # The L1 regularization coefficient for the feature activations.
l1_coefficient = 2e-4, # The L1 regularization coefficient for the feature activations.
l1_coefficient_warmup_steps = 10000, # The number of warm-up steps for the L1 regularization coefficient.
lp = 1, # The p-norm to use for the L1 regularization.
use_ghost_grads = True, # Whether to use the ghost gradients for saving dead features.
use_ghost_grads = False, # Whether to use the ghost gradients for saving dead features.
init_decoder_norm = 'auto',
init_encoder_with_decoder_transpose = True,
apply_decoder_bias_to_pre_encoder = True,
sparsity_include_decoder_norm = True,

# LanguageModelSAETrainingConfig
total_training_tokens = 1_600_000_000, # The total number of tokens to train the dictionary.
lr = 4e-4, # The learning rate for the dictionary training.
betas = (0, 0.9999), # The betas for the Adam optimizer.
total_training_tokens = 300_000_000, # The total number of tokens to train the dictionary.
lr = 1e-4, # The learning rate for the dictionary training.
betas = (0.9, 0.9999), # The betas for the Adam optimizer.

lr_scheduler_name = "constantwithwarmup", # The learning rate scheduler name. Can be "constant", "constantwithwarmup", "linearwarmupdecay", "cosineannealing", "cosineannealingwarmup" or "exponentialwarmup".
lr_warm_up_steps = 5000, # The number of warm-up steps for the learning rate.
lr_cool_down_steps = 10000, # The number of cool-down steps for the learning rate. Currently only used for the "constantwithwarmup" scheduler.
lr_warm_up_steps = 2000, # The number of warm-up steps for the learning rate.
lr_cool_down_steps = 4000, # The number of cool-down steps for the learning rate. Currently only used for the "constantwithwarmup" scheduler.
clip_grad_norm = 0.0, # The maximum gradient norm for clipping. If 0.0, no gradient clipping will be performed.
train_batch_size = 4096, # The batch size for training the dictionary, i.e. the number of token activations in a batch.
feature_sampling_window = 1000, # The window size for sampling the feature activations.
dead_feature_window = 5000, # The window size for detecting the dead features.
dead_feature_threshold = 1e-6, # The threshold for detecting the dead features.
eval_frequency = 1000, # The step frequency for evaluating the dictionary.
log_frequency = 100, # The step frequency for logging the training information (to wandb).
n_checkpoints = 10, # The number of checkpoints to save during the training.
remove_gradient_parallel_to_decoder_directions = False,


# WandbConfig
log_to_wandb = True, # Whether to log the training information to wandb.
wandb_project= "gpt2-sae", # The wandb project name.
wandb_project= "llama3-sae", # The wandb project name.

# RunnerConfig
device = "cuda", # The device to place all torch tensors.
seed = 42, # The random seed.
dtype = torch.float32, # The torch data type of non-integer tensors.
dtype = torch.bfloat16, # The torch data type of non-integer tensors.

exp_name = "L3M", # The experiment name. Would be used for creating exp folder (which may contain checkpoints and analysis results) and setting wandb run name.
exp_series = "default",
exp_name = "llama-L16RPr", # The experiment name. Would be used for creating exp folder (which may contain checkpoints and analysis results) and setting wandb run name.
exp_series = "llama3-sae",
exp_result_dir = "results"
))
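For reference, the SAE settings above determine the dictionary size via the `d_sae = expansion_factor * d_model` rule in `SAEConfig`:

d_model = 4096                       # Llama-3-8B residual stream width
expansion_factor = 32
d_sae = expansion_factor * d_model   # 131072 dictionary features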

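The multi-corpus setup pairs each `dataset_path` entry with a value in `sample_probs = [1, 8, 1, 2, 4, 4, 1]`, which presumably act as relative sampling weights across the seven datasets. A hedged sketch of that interpretation (the activation store's actual sampling logic may differ):

import random

sample_probs = [1, 8, 1, 2, 4, 4, 1]
weights = [p / sum(sample_probs) for p in sample_probs]  # normalised; the C4 split gets 8/21 ≈ 0.38

rng = random.Random(42)

def pick_dataset_index() -> int:
    # Draw one dataset index per record, proportional to its weight.
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]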
32 changes: 26 additions & 6 deletions src/lm_saes/config.py
@@ -163,28 +163,46 @@ class SAEConfig(BaseModelConfig):
sae_pretrained_name_or_path: Optional[str] = None
strict_loading: bool = True

use_decoder_bias: bool = False
use_decoder_bias: bool = True
apply_decoder_bias_to_pre_encoder: bool = True # set to False when training transcoders
decoder_bias_init_method: str = "geometric_median"
expansion_factor: int = 32
expansion_factor: int = 128
d_model: int = 768
d_sae: int = None # type: ignore
""" The dimension of the SAE, i.e. the number of dictionary components (or features). If None, it will be set to d_model * expansion_factor """
norm_activation: str = "token-wise" # none, token-wise, batch-wise
decoder_exactly_unit_norm: bool = True
norm_activation: str = "token-wise" # none, token-wise, batch-wise, dataset-wise
dataset_average_activation_norm: Dict[str, float] | None = None
decoder_exactly_fixed_norm: bool = False
sparsity_include_decoder_norm: bool = True # set to True: sparsity loss = sum(act * corresponding_decoder_norm), otherwise loss = sum(act). Incompatible with decoder_exactly_fixed_norm
use_glu_encoder: bool = False
init_decoder_norm: float | str = 'auto'
init_encoder_norm: float | str | None = 'auto'
init_encoder_with_decoder_transpose: bool = True

l1_coefficient: float = 0.00008
l1_coefficient_warmup_steps: int = 0
lp: int = 1

use_ghost_grads: bool = True
use_ghost_grads: bool = False

def __post_init__(self):
super().__post_init__()
if self.hook_point_out is None:
self.hook_point_out = self.hook_point_in
if self.d_sae is None:
self.d_sae = self.d_model * self.expansion_factor
if self.norm_activation == "dataset-wise" and self.dataset_average_activation_norm is None:
print('dataset_average_activation_norm is None and norm_activation is "dataset-wise". It will be computed automatically from the dataset.')
if self.sparsity_include_decoder_norm and self.decoder_exactly_fixed_norm:
raise ValueError("sparsity_include_decoder_norm and decoder_exactly_fixed_norm are incompatible.")
if self.sparsity_include_decoder_norm and self.use_ghost_grads:
raise ValueError("sparsity_include_decoder_norm and use_ghost_grads are incompatible.")
if isinstance(self.init_decoder_norm, str) and self.init_decoder_norm != 'auto':
raise ValueError("init_decoder_norm must be a float or 'auto'.")
if isinstance(self.init_encoder_norm, str) and self.init_encoder_norm != 'auto':
raise ValueError("init_encoder_norm must be None, a float or 'auto'.")
if self.init_encoder_with_decoder_transpose and isinstance(self.init_encoder_norm, float):
raise ValueError("init_encoder_with_decoder_transpose and init_encoder_norm with float are incompatible.")


@staticmethod
def from_pretrained(pretrained_name_or_path: str, strict_loading: bool = True, **kwargs):
@@ -269,6 +287,8 @@ class LanguageModelSAETrainingConfig(LanguageModelSAERunnerConfig):
lr_warm_up_steps: int = 5000
lr_cool_down_steps: int = 10000
train_batch_size: int = 4096
clip_grad_norm: float = 0.0
remove_gradient_parallel_to_decoder_directions: bool = False

finetuning: bool = False

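The new `init_decoder_norm`, `init_encoder_norm` and `init_encoder_with_decoder_transpose` fields point to an initialisation scheme where decoder directions start at a chosen norm and the encoder starts as the decoder's transpose. The following is a rough sketch of that idea only; the repository's actual initialiser, in particular how the 'auto' norm is picked from the dataset by `init_sae_on_dataset`, may differ:

import torch

def init_sae_weights(d_model: int, d_sae: int,
                     decoder_norm: float = 0.1,
                     encoder_from_decoder_transpose: bool = True):
    # Random decoder, rescaled row-wise so every dictionary direction starts
    # with the same L2 norm (decoder_norm stands in for the resolved 'auto' value).
    W_dec = torch.randn(d_sae, d_model)
    W_dec = W_dec / W_dec.norm(dim=1, keepdim=True) * decoder_norm
    # Optionally tie the encoder to the decoder transpose at initialisation.
    W_enc = W_dec.t().clone() if encoder_from_decoder_transpose else torch.randn(d_model, d_sae)
    return W_enc, W_dec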
31 changes: 22 additions & 9 deletions src/lm_saes/runner.py
@@ -25,20 +25,13 @@
from lm_saes.sae import SparseAutoEncoder
from lm_saes.activation.activation_dataset import make_activation_dataset
from lm_saes.activation.activation_store import ActivationStore
from lm_saes.sae_training import prune_sae, train_sae
from lm_saes.sae_training import prune_sae, train_sae, init_sae_on_dataset
from lm_saes.analysis.sample_feature_activations import sample_feature_activations
from lm_saes.analysis.features_to_logits import features_to_logits


def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))
sae = SparseAutoEncoder.from_config(cfg=cfg.sae)

if cfg.finetuning:
# Fine-tune SAE with frozen encoder weights and bias
sae.train_finetune_for_suppression_parameters()

def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
hf_model = AutoModelForCausalLM.from_pretrained(
(
cfg.lm.model_name
@@ -72,6 +65,26 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
model.eval()
activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store)

if (
cfg.sae.norm_activation == "dataset-wise" and cfg.sae.dataset_average_activation_norm is None
or cfg.sae.init_decoder_norm == 'auto'
):
assert not cfg.finetuning
sae = init_sae_on_dataset(
model,
activation_store,
cfg
)
else:
sae = SparseAutoEncoder.from_config(cfg=cfg.sae)

if cfg.finetuning:
# Fine-tune SAE with frozen encoder weights and bias
sae.train_finetune_for_suppression_parameters()

cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))

if cfg.wandb.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0):
wandb_config: dict = {
**asdict(cfg),
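One subtlety in the new runner guard: `A and B or C` parses as `(A and B) or C`, so dataset-based initialisation is triggered either when dataset-wise normalisation still needs its average activation norm, or whenever `init_decoder_norm == 'auto'`. Restated with explicit parentheses in a hypothetical helper (not part of the repository):

def needs_dataset_init(norm_activation: str,
                       dataset_average_activation_norm,
                       init_decoder_norm) -> bool:
    # Same condition as the runner, with the precedence made explicit.
    needs_average_norm = (norm_activation == "dataset-wise"
                          and dataset_average_activation_norm is None)
    return needs_average_norm or init_decoder_norm == "auto"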
(Diffs for the remaining 2 changed files were not loaded.)
