April tricks #28

Merged · 7 commits · Jul 2, 2024
51 changes: 31 additions & 20 deletions examples/programmatic/train.py
@@ -2,61 +2,72 @@
from lm_saes.config import LanguageModelSAETrainingConfig
from lm_saes.runner import language_model_sae_runner


cfg = LanguageModelSAETrainingConfig.from_flattened(dict(
# LanguageModelConfig
model_name = "gpt2", # The model name or path for the pre-trained model.
d_model = 768, # The hidden size of the model.

# TextDatasetConfig
dataset_path = "openwebtext", # The corpus name or path. Each of a data record should contain (and may only contain) a "text" field.
dataset_path = 'Skylion007/OpenWebText', # The corpus name or path. Each of a data record should contain (and may only contain) a "text" field.
is_dataset_tokenized = False, # Whether the dataset is tokenized.
is_dataset_on_disk = True, # Whether the dataset is on disk. If not on disk, `datasets.load_dataset` will be used to load the dataset, and the train split will be used for training.
concat_tokens = False, # Whether to concatenate tokens into a single sequence. If False, only data records with more than `context_size` non-padding tokens will be used.
context_size = 256, # The sequence length of the text dataset.
store_batch_size = 32, # The batch size for loading the corpus.
concat_tokens = True, # Whether to concatenate tokens into a single sequence. If False, only data records with more than `context_size` non-padding tokens will be used.
context_size = 1024, # The sequence length of the text dataset.
store_batch_size = 20, # The batch size for loading the corpus.

# ActivationStoreConfig
hook_points = ["blocks.3.hook_mlp_out"], # Hook points to store activations from, i.e. the layer output of which is used for training/evaluating the dictionary. Will run until the last hook point in the list, so make sure to order them correctly.
hook_points = ['blocks.8.hook_resid_pre'], # Hook points to store activations from, i.e. the layer output of which is used for training/evaluating the dictionary. Will run until the last hook point in the list, so make sure to order them correctly.
use_cached_activations = False, # Whether to use cached activations. Caching activations is currently not recommended, as it may consume an extremely large amount of disk space (possibly tens of TBs for a corpus like `openwebtext`).
n_tokens_in_buffer = 500_000, # The number of tokens to store in the activation buffer. The buffer is used to shuffle the activations before training the dictionary.

# SAEConfig
hook_point_in = "blocks.3.hook_mlp_out",
hook_point_out = "blocks.3.hook_mlp_out",
expansion_factor = 32, # The expansion factor of the dictionary. d_sae = expansion_factor * d_model.
hook_point_in = 'blocks.8.hook_resid_pre',
hook_point_out = 'blocks.8.hook_resid_pre',
use_decoder_bias = True, # Whether to use decoder bias.
expansion_factor = 128, # The expansion factor of the dictionary. d_sae = expansion_factor * d_model.
norm_activation = "token-wise", # The normalization method for the activations. Can be "token-wise", "batch-wise" or "none".
decoder_exactly_unit_norm = False, # Whether to constrain the decoder to exactly unit norm. If False, the decoder norm will be less than or equal to one.
decoder_exactly_fixed_norm = False, # Whether to constrain the decoder to exactly unit norm. If False, the decoder norm will be less than or equal to one.
use_glu_encoder = False, # Whether to use the Gated Linear Unit (GLU) for the encoder.
l1_coefficient = 1.2e-4, # The L1 regularization coefficient for the feature activations.
l1_coefficient = 2e-4, # The L1 regularization coefficient for the feature activations.
l1_coefficient_warmup_steps = 10000, # The number of warm-up steps for the L1 regularization coefficient.
lp = 1, # The p-norm to use for the L1 regularization.
use_ghost_grads = True, # Whether to use the ghost gradients for saving dead features.
use_ghost_grads = False, # Whether to use the ghost gradients for saving dead features.
init_decoder_norm = None, # The initial norm of the decoder. If None, the decoder will be initialized automatically with the lowest MSE.
init_encoder_with_decoder_transpose = True,
apply_decoder_bias_to_pre_encoder = True,
sparsity_include_decoder_norm = True,

# LanguageModelSAETrainingConfig
total_training_tokens = 1_600_000_000, # The total number of tokens to train the dictionary.
lr = 4e-4, # The learning rate for the dictionary training.
betas = (0, 0.9999), # The betas for the Adam optimizer.
total_training_tokens = 100_000_000, # The total number of tokens to train the dictionary.
lr = 1e-4, # The learning rate for the dictionary training.
betas = (0.9, 0.9999), # The betas for the Adam optimizer.

lr_scheduler_name = "constantwithwarmup", # The learning rate scheduler name. Can be "constant", "constantwithwarmup", "linearwarmupdecay", "cosineannealing", "cosineannealingwarmup" or "exponentialwarmup".
lr_warm_up_steps = 5000, # The number of warm-up steps for the learning rate.
lr_cool_down_steps = 10000, # The number of cool-down steps for the learning rate. Currently only used for the "constantwithwarmup" scheduler.
lr_warm_up_steps = 2000, # The number of warm-up steps for the learning rate.
lr_cool_down_steps = 4000, # The number of cool-down steps for the learning rate. Currently only used for the "constantwithwarmup" scheduler.
clip_grad_norm = 0.0, # The maximum gradient norm for clipping. If 0.0, no gradient clipping will be performed.
train_batch_size = 4096, # The batch size for training the dictionary, i.e. the number of token activations in a batch.
feature_sampling_window = 1000, # The window size for sampling the feature activations.
dead_feature_window = 5000, # The window size for detecting the dead features.
dead_feature_threshold = 1e-6, # The threshold for detecting the dead features.
eval_frequency = 1000, # The step frequency for evaluating the dictionary.
log_frequency = 100, # The step frequency for logging the training information (to wandb).
n_checkpoints = 10, # The number of checkpoints to save during the training.
remove_gradient_parallel_to_decoder_directions = False,


# WandbConfig
log_to_wandb = True, # Whether to log the training information to wandb.
wandb_project= "gpt2-sae", # The wandb project name.
wandb_project= "test", # The wandb project name.

# RunnerConfig
device = "cuda", # The device to place all torch tensors.
seed = 42, # The random seed.
dtype = torch.float32, # The torch data type of non-integer tensors.

exp_name = "L3M", # The experiment name. Would be used for creating exp folder (which may contain checkpoints and analysis results) and setting wandb run name.
exp_series = "default",
exp_name = f"test", # The experiment name. Would be used for creating exp folder (which may contain checkpoints and analysis results) and setting wandb run name.
exp_series = "test",
exp_result_dir = "results"
))
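The truncated remainder of the example presumably just hands this config to the runner imported at the top of the file; a minimal sketch of that call (not part of the diff):

# Launch dictionary training with the config built above.
language_model_sae_runner(cfg)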

2 changes: 1 addition & 1 deletion src/lm_saes/activation/activation_source.py
@@ -40,7 +40,7 @@ def __init__(self, model: HookedTransformer, cfg: ActivationStoreConfig):
self.cfg = cfg

def next(self) -> Dict[str, torch.Tensor] | None:
tokens = self.token_source.next(self.cfg.dataset.store_batch_size)
tokens = self.next_tokens(self.cfg.dataset.store_batch_size)

if tokens is None:
return None
9 changes: 6 additions & 3 deletions src/lm_saes/activation/token_source.py
@@ -17,6 +17,7 @@ def __init__(
concat_tokens: list[bool],
seq_len: int,
sample_probs: list[float],
prepend_bos: list[bool]
):
self.dataloader = dataloader
self.model = model
@@ -33,13 +34,14 @@ def __init__(
self.resid = torch.tensor([], dtype=torch.long, device=self.device)

self.sample_probs = sample_probs
self.prepend_bos = prepend_bos


def fill_with_one_batch(self, batch, pack) -> None:
def fill_with_one_batch(self, batch, pack: bool, prepend_bos: bool) -> None:
if self.is_dataset_tokenized:
tokens: torch.Tensor = batch["tokens"].to(self.device)
else:
tokens = self.model.to_tokens(batch["text"], prepend_bos=False).to(self.device)
tokens = self.model.to_tokens(batch["text"], prepend_bos=prepend_bos).to(self.device)
if pack:
while tokens.size(0) > 0:
cur_tokens = tokens[0]
@@ -81,7 +83,7 @@ def next(self, batch_size: int) -> torch.Tensor | None:
else:
return None

self.fill_with_one_batch(batch, self.concat_tokens[dataset_idx_to_fetch])
self.fill_with_one_batch(batch, self.concat_tokens[dataset_idx_to_fetch], prepend_bos=self.prepend_bos[dataset_idx_to_fetch])

ret = self.token_buffer[:batch_size]
self.token_buffer = self.token_buffer[batch_size:]
@@ -120,4 +122,5 @@ def from_config(model: HookedTransformer, cfg: TextDatasetConfig):
concat_tokens=cfg.concat_tokens,
seq_len=cfg.context_size,
sample_probs=cfg.sample_probs,
prepend_bos=cfg.prepend_bos
)
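The new per-dataset prepend_bos flag is passed straight through to HookedTransformer.to_tokens when raw text is tokenized (see fill_with_one_batch above). A small standalone sketch of what the flag controls, using the public transformer_lens API (the model choice here is only illustrative):

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
with_bos = model.to_tokens("Hello world", prepend_bos=True)      # BOS token prepended
without_bos = model.to_tokens("Hello world", prepend_bos=False)  # raw tokens only
assert with_bos.shape[-1] == without_bos.shape[-1] + 1           # BOS adds exactly one token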
33 changes: 27 additions & 6 deletions src/lm_saes/config.py
@@ -106,6 +106,7 @@ class TextDatasetConfig(RunnerConfig):
context_size: int = 128
store_batch_size: int = 64
sample_probs: List[float] = field(default_factory=lambda: [1.0])
prepend_bos: List[bool] = field(default_factory=lambda: [False])

def __post_init__(self):
super().__post_init__()
@@ -115,10 +116,14 @@ def __post_init__(self):
if isinstance(self.concat_tokens, bool):
self.concat_tokens = [self.concat_tokens]

if isinstance(self.prepend_bos, bool):
self.prepend_bos = [self.prepend_bos]

self.sample_probs = [p / sum(self.sample_probs) for p in self.sample_probs]

assert len(self.sample_probs) == len(self.dataset_path), "Number of sample_probs must match number of dataset paths"
assert len(self.concat_tokens) == len(self.dataset_path), "Number of concat_tokens must match number of dataset paths"
assert len(self.prepend_bos) == len(self.dataset_path), "Number of prepend_bos must match number of dataset paths"
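As a concrete illustration of these list-valued, per-dataset options, a hypothetical two-corpus configuration (field names follow this PR's config.py; the second dataset path is made up, and any fields not shown are assumed to have workable defaults):

from lm_saes.config import TextDatasetConfig

cfg = TextDatasetConfig(
    dataset_path=["Skylion007/OpenWebText", "my/other-corpus"],  # second corpus is hypothetical
    concat_tokens=[True, False],   # per-dataset packing behaviour
    sample_probs=[0.9, 0.1],       # normalized to sum to 1 in __post_init__
    prepend_bos=[True, False],     # one flag per dataset path, length-checked above
    context_size=1024,
    store_batch_size=20,
)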


@dataclass(kw_only=True)
@@ -163,28 +168,42 @@ class SAEConfig(BaseModelConfig):
sae_pretrained_name_or_path: Optional[str] = None
strict_loading: bool = True

use_decoder_bias: bool = False
use_decoder_bias: bool = True
apply_decoder_bias_to_pre_encoder: bool = True # set to False when training transcoders
decoder_bias_init_method: str = "geometric_median"
expansion_factor: int = 32
expansion_factor: int = 128
d_model: int = 768
d_sae: int = None # type: ignore
""" The dimension of the SAE, i.e. the number of dictionary components (or features). If None, it will be set to d_model * expansion_factor """
norm_activation: str = "token-wise" # none, token-wise, batch-wise
decoder_exactly_unit_norm: bool = True
norm_activation: str = "token-wise" # none, token-wise, batch-wise, dataset-wise
dataset_average_activation_norm: Dict[str, float] | None = None
decoder_exactly_fixed_norm: bool = False
sparsity_include_decoder_norm: bool = True # set to True: sparsity loss = sum(act * corresponding_decoder_norm), otherwise loss = sum(act). Incompatible with decoder_exactly_fixed_norm
use_glu_encoder: bool = False
init_decoder_norm: float | None = None # type: ignore
init_encoder_norm: float | None = None # type: ignore
init_encoder_with_decoder_transpose: bool = True

l1_coefficient: float = 0.00008
l1_coefficient_warmup_steps: int = 0
lp: int = 1

use_ghost_grads: bool = True
use_ghost_grads: bool = False

def __post_init__(self):
super().__post_init__()
if self.hook_point_out is None:
self.hook_point_out = self.hook_point_in
if self.d_sae is None:
self.d_sae = self.d_model * self.expansion_factor
if self.norm_activation == "dataset-wise" and self.dataset_average_activation_norm is None:
print(f'dataset_average_activation_norm is None and norm_activation is "dataset-wise". Will be computed automatically from the dataset.')
if self.sparsity_include_decoder_norm and self.decoder_exactly_fixed_norm:
raise ValueError("sparsity_include_decoder_norm and decoder_exactly_fixed_norm are incompatible.")
if self.sparsity_include_decoder_norm and self.use_ghost_grads:
raise ValueError("sparsity_include_decoder_norm and use_ghost_grads are incompatible.")
if self.init_encoder_with_decoder_transpose and isinstance(self.init_encoder_norm, float):
raise ValueError("init_encoder_with_decoder_transpose and init_encoder_norm with float are incompatible.")


@staticmethod
def from_pretrained(pretrained_name_or_path: str, strict_loading: bool = True, **kwargs):
@@ -269,6 +288,8 @@ class LanguageModelSAETrainingConfig(LanguageModelSAERunnerConfig):
lr_warm_up_steps: int = 5000
lr_cool_down_steps: int = 10000
train_batch_size: int = 4096
clip_grad_norm: float = 0.0
remove_gradient_parallel_to_decoder_directions: bool = False

finetuning: bool = False

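With the new defaults (d_model = 768, expansion_factor = 128), d_sae defaults to 768 × 128 = 98,304 features. The added __post_init__ checks reject flag combinations that cannot work together; a hypothetical sketch of tripping one of them (assuming the SAEConfig fields not shown here have usable defaults):

from lm_saes.config import SAEConfig

try:
    SAEConfig(
        hook_point_in="blocks.8.hook_resid_pre",
        d_model=768,
        sparsity_include_decoder_norm=True,  # the new default in this PR
        use_ghost_grads=True,                # now defaults to False; forcing it on conflicts
    )
except ValueError as err:
    print(err)  # "sparsity_include_decoder_norm and use_ghost_grads are incompatible."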
28 changes: 20 additions & 8 deletions src/lm_saes/runner.py
@@ -30,15 +30,8 @@
from lm_saes.analysis.features_to_logits import features_to_logits


def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))
sae = SparseAutoEncoder.from_config(cfg=cfg.sae)

if cfg.finetuning:
# Fine-tune SAE with frozen encoder weights and bias
sae.train_finetune_for_suppression_parameters()

def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
hf_model = AutoModelForCausalLM.from_pretrained(
(
cfg.lm.model_name
@@ -72,6 +65,25 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig):
model.eval()
activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store)

if (
cfg.sae.norm_activation == "dataset-wise" and cfg.sae.dataset_average_activation_norm is None
or cfg.sae.init_decoder_norm is None
):
assert not cfg.finetuning
sae = SparseAutoEncoder.from_initialization_searching(
activation_store=activation_store,
cfg=cfg,
)
else:
sae = SparseAutoEncoder.from_config(cfg=cfg.sae)

if cfg.finetuning:
# Fine-tune SAE with frozen encoder weights and bias
sae.train_finetune_for_suppression_parameters()

cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name))
cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name))

if cfg.wandb.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0):
wandb_config: dict = {
**asdict(cfg),
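The reordered runner now builds the SAE only after the activation store exists, presumably so that the new SparseAutoEncoder.from_initialization_searching path can draw on stored activations. A hypothetical fragment of flattened settings (field names from this PR's config.py) that would take that branch; the remaining fields are assumed to be set as in examples/programmatic/train.py:

searching_settings = dict(
    norm_activation = "dataset-wise",        # new normalization mode added in config.py
    dataset_average_activation_norm = None,  # None: estimated automatically from the dataset
    init_decoder_norm = None,                # None: initial decoder norm searched for the lowest MSE
)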