Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…into 11-proposal-accelerate-inference-in-transformerlens
  • Loading branch information
StarConnor committed Jul 2, 2024
2 parents 5447661 + 5ef3cba commit 90a3361
Show file tree
Hide file tree
Showing 5 changed files with 239 additions and 62 deletions.
51 changes: 31 additions & 20 deletions examples/programmatic/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,61 +2,72 @@
from lm_saes.config import LanguageModelSAETrainingConfig
from lm_saes.runner import language_model_sae_runner


cfg = LanguageModelSAETrainingConfig.from_flattened(dict(
# LanguageModelConfig
model_name = "gpt2", # The model name or path for the pre-trained model.
d_model = 768, # The hidden size of the model.

# TextDatasetConfig
dataset_path = "openwebtext", # The corpus name or path. Each of a data record should contain (and may only contain) a "text" field.
dataset_path = 'Skylion007/OpenWebText', # The corpus name or path. Each of a data record should contain (and may only contain) a "text" field.
is_dataset_tokenized = False, # Whether the dataset is tokenized.
is_dataset_on_disk = True, # Whether the dataset is on disk. If not on disk, `datasets.load_dataset`` will be used to load the dataset, and the train split will be used for training.
concat_tokens = False, # Whether to concatenate tokens into a single sequence. If False, only data record with length of non-padding tokens larger than `context_size` will be used.
context_size = 256, # The sequence length of the text dataset.
store_batch_size = 32, # The batch size for loading the corpus.
concat_tokens = True, # Whether to concatenate tokens into a single sequence. If False, only data record with length of non-padding tokens larger than `context_size` will be used.
context_size = 1024, # The sequence length of the text dataset.
store_batch_size = 20, # The batch size for loading the corpus.

# ActivationStoreConfig
hook_points = ["blocks.3.hook_mlp_out"], # Hook points to store activations from, i.e. the layer output of which is used for training/evaluating the dictionary. Will run until the last hook point in the list, so make sure to order them correctly.
hook_points = ['blocks.8.hook_resid_pre'], # Hook points to store activations from, i.e. the layer output of which is used for training/evaluating the dictionary. Will run until the last hook point in the list, so make sure to order them correctly.
use_cached_activations = False, # Whether to use cached activations. Caching activation is now not recommended, as it may consume extremely large disk space. (May be tens of TBs for corpus like `openwebtext`)
n_tokens_in_buffer = 500_000, # The number of tokens to store in the activation buffer. The buffer is used to shuffle the activations before training the dictionary.

# SAEConfig
hook_point_in = "blocks.3.hook_mlp_out",
hook_point_out = "blocks.3.hook_mlp_out",
expansion_factor = 32, # The expansion factor of the dictionary. d_sae = expansion_factor * d_model.
hook_point_in = 'blocks.8.hook_resid_pre',
hook_point_out = 'blocks.8.hook_resid_pre',
use_decoder_bias = True, # Whether to use decoder bias.
expansion_factor = 128, # The expansion factor of the dictionary. d_sae = expansion_factor * d_model.
norm_activation = "token-wise", # The normalization method for the activations. Can be "token-wise", "batch-wise" or "none".
decoder_exactly_unit_norm = False, # Whether to enforce the decoder to have exactly unit norm. If False, the decoder will have less than or equal to unit norm.
decoder_exactly_fixed_norm = False, # Whether to enforce the decoder to have exactly unit norm. If False, the decoder will have less than or equal to unit norm.
use_glu_encoder = False, # Whether to use the Gated Linear Unit (GLU) for the encoder.
l1_coefficient = 1.2e-4, # The L1 regularization coefficient for the feature activations.
l1_coefficient = 2e-4, # The L1 regularization coefficient for the feature activations.
l1_coefficient_warmup_steps = 10000, # The number of warm-up steps for the L1 regularization coefficient.
lp = 1, # The p-norm to use for the L1 regularization.
use_ghost_grads = True, # Whether to use the ghost gradients for saving dead features.
use_ghost_grads = False, # Whether to use the ghost gradients for saving dead features.
init_decoder_norm = None, # The initial norm of the decoder. If None, the decoder will be initialized automatically with the lowest MSE.
init_encoder_with_decoder_transpose = True,
apply_decoder_bias_to_pre_encoder = True,
sparsity_include_decoder_norm = True,

# LanguageModelSAETrainingConfig
total_training_tokens = 1_600_000_000, # The total number of tokens to train the dictionary.
lr = 4e-4, # The learning rate for the dictionary training.
betas = (0, 0.9999), # The betas for the Adam optimizer.
total_training_tokens = 100_000_000, # The total number of tokens to train the dictionary.
lr = 1e-4, # The learning rate for the dictionary training.
betas = (0.9, 0.9999), # The betas for the Adam optimizer.

lr_scheduler_name = "constantwithwarmup", # The learning rate scheduler name. Can be "constant", "constantwithwarmup", "linearwarmupdecay", "cosineannealing", "cosineannealingwarmup" or "exponentialwarmup".
lr_warm_up_steps = 5000, # The number of warm-up steps for the learning rate.
lr_cool_down_steps = 10000, # The number of cool-down steps for the learning rate. Currently only used for the "constantwithwarmup" scheduler.
lr_warm_up_steps = 2000, # The number of warm-up steps for the learning rate.
lr_cool_down_steps = 4000, # The number of cool-down steps for the learning rate. Currently only used for the "constantwithwarmup" scheduler.
clip_grad_norm = 0.0, # The maximum gradient norm for clipping. If 0.0, no gradient clipping will be performed.
train_batch_size = 4096, # The batch size for training the dictionary, i.e. the number of token activations in a batch.
feature_sampling_window = 1000, # The window size for sampling the feature activations.
dead_feature_window = 5000, # The window size for detecting the dead features.
dead_feature_threshold = 1e-6, # The threshold for detecting the dead features.
eval_frequency = 1000, # The step frequency for evaluating the dictionary.
log_frequency = 100, # The step frequency for logging the training information (to wandb).
n_checkpoints = 10, # The number of checkpoints to save during the training.
remove_gradient_parallel_to_decoder_directions = False,


# WandbConfig
log_to_wandb = True, # Whether to log the training information to wandb.
wandb_project= "gpt2-sae", # The wandb project name.
wandb_project= "test", # The wandb project name.

# RunnerConfig
device = "cuda", # The device to place all torch tensors.
seed = 42, # The random seed.
dtype = torch.float32, # The torch data type of non-integer tensors.

exp_name = "L3M", # The experiment name. Would be used for creating exp folder (which may contain checkpoints and analysis results) and setting wandb run name.
exp_series = "default",
exp_name = f"test", # The experiment name. Would be used for creating exp folder (which may contain checkpoints and analysis results) and setting wandb run name.
exp_series = "test",
exp_result_dir = "results"
))

Expand Down
28 changes: 22 additions & 6 deletions src/lm_saes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,28 +164,42 @@ class SAEConfig(BaseModelConfig):
sae_pretrained_name_or_path: Optional[str] = None
strict_loading: bool = True

use_decoder_bias: bool = False
use_decoder_bias: bool = True
apply_decoder_bias_to_pre_encoder: bool = True # set to False when training transcoders
decoder_bias_init_method: str = "geometric_median"
expansion_factor: int = 32
expansion_factor: int = 128
d_model: int = 768
d_sae: int = None # type: ignore
""" The dimension of the SAE, i.e. the number of dictionary components (or features). If None, it will be set to d_model * expansion_factor """
norm_activation: str = "token-wise" # none, token-wise, batch-wise
decoder_exactly_unit_norm: bool = True
norm_activation: str = "token-wise" # none, token-wise, batch-wise, dataset-wise
dataset_average_activation_norm: Dict[str, float] | None = None
decoder_exactly_fixed_norm: bool = False
sparsity_include_decoder_norm: bool = True # set to True: sparsity loss = sum(act * corresponding_decoder_norm), otherwise loss = sum(act). Incompatible with decoder_exactly_fixed_norm
use_glu_encoder: bool = False
init_decoder_norm: float | None = None # type: ignore
init_encoder_norm: float | None = None # type: ignore
init_encoder_with_decoder_transpose: bool = True

l1_coefficient: float = 0.00008
l1_coefficient_warmup_steps: int = 0
lp: int = 1

use_ghost_grads: bool = True
use_ghost_grads: bool = False

def __post_init__(self):
super().__post_init__()
if self.hook_point_out is None:
self.hook_point_out = self.hook_point_in
if self.d_sae is None:
self.d_sae = self.d_model * self.expansion_factor
if self.norm_activation == "dataset-wise" and self.dataset_average_activation_norm is None:
print(f'dataset_average_activation_norm is None and norm_activation is "dataset-wise". Will be computed automatically from the dataset.')
if self.sparsity_include_decoder_norm and self.decoder_exactly_fixed_norm:
raise ValueError("sparsity_include_decoder_norm and decoder_exactly_fixed_norm are incompatible.")
if self.sparsity_include_decoder_norm and self.use_ghost_grads:
raise ValueError("sparsity_include_decoder_norm and use_ghost_grads are incompatible.")
if self.init_encoder_with_decoder_transpose and isinstance(self.init_encoder_norm, float):
raise ValueError("init_encoder_with_decoder_transpose and init_encoder_norm with float are incompatible.")


@staticmethod
def from_pretrained(pretrained_name_or_path: str, strict_loading: bool = True, **kwargs):
Expand Down Expand Up @@ -270,6 +284,8 @@ class LanguageModelSAETrainingConfig(LanguageModelSAERunnerConfig):
lr_warm_up_steps: int = 5000
lr_cool_down_steps: int = 10000
train_batch_size: int = 4096
clip_grad_norm: float = 0.0
remove_gradient_parallel_to_decoder_directions: bool = False

finetuning: bool = False

Expand Down
28 changes: 20 additions & 8 deletions src/lm_saes/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,8 @@
from lm_saes.analysis.features_to_logits import features_to_logits


def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))
sae = SparseAutoEncoder.from_config(cfg=cfg.sae)

if cfg.finetuning:
# Fine-tune SAE with frozen encoder weights and bias
sae.train_finetune_for_suppression_parameters()

def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
hf_model = AutoModelForCausalLM.from_pretrained(
(
cfg.lm.model_name
Expand Down Expand Up @@ -73,6 +66,25 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
model.eval()
activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store)

if (
cfg.sae.norm_activation == "dataset-wise" and cfg.sae.dataset_average_activation_norm is None
or cfg.sae.init_decoder_norm is None
):
assert not cfg.finetuning
sae = SparseAutoEncoder.from_initialization_searching(
activation_store=activation_store,
cfg=cfg,
)
else:
sae = SparseAutoEncoder.from_config(cfg=cfg.sae)

if cfg.finetuning:
# Fine-tune SAE with frozen encoder weights and bias
sae.train_finetune_for_suppression_parameters()

cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))

if cfg.wandb.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0):
wandb_config: dict = {
**asdict(cfg),
Expand Down
Loading

0 comments on commit 90a3361

Please sign in to comment.