diff --git a/src/lm_saes/config.py b/src/lm_saes/config.py
index 62fdded..e0a7b53 100644
--- a/src/lm_saes/config.py
+++ b/src/lm_saes/config.py
@@ -2,6 +2,7 @@
 from dataclasses import dataclass, field, fields
 from typing import Any, Dict, List, Optional, Tuple
 from typing_extensions import deprecated
+import math
 
 import torch
 import torch.distributed as dist
@@ -344,6 +345,7 @@ class LanguageModelSAETrainingConfig(LanguageModelSAERunnerConfig):
     log_frequency: int = 10
 
     n_checkpoints: int = 10
+    check_point_save_mode: str = 'log' # 'log' or 'linear'
 
     def __post_init__(self):
         super().__post_init__()
@@ -399,6 +401,19 @@ def __post_init__(self):
         if self.finetuning:
             assert self.sae.l1_coefficient == 0.0, "L1 coefficient must be 0.0 for finetuning."
 
+        if self.n_checkpoints > 0:
+            if self.check_point_save_mode == 'linear':
+                self.checkpoint_thresholds = list(
+                    range(0, total_training_tokens, total_training_tokens // self.n_checkpoints)
+                )[1:]
+            elif self.check_point_save_mode == 'log':
+                self.checkpoint_thresholds = [
+                    math.ceil(2 ** (i / self.n_checkpoints * math.log2(total_training_steps))) * self.effective_batch_size
+                    for i in range(1, self.n_checkpoints)
+                ]
+            else:
+                raise ValueError(f"Unknown checkpoint save mode: {self.check_point_save_mode}")
+
 @dataclass(kw_only=True)
 class LanguageModelSAEPruningConfig(LanguageModelSAERunnerConfig):
     """
diff --git a/src/lm_saes/sae_training.py b/src/lm_saes/sae_training.py
index 5629067..30138b3 100644
--- a/src/lm_saes/sae_training.py
+++ b/src/lm_saes/sae_training.py
@@ -13,6 +13,7 @@
 from tqdm import tqdm
 
 import wandb
+import math
 
 from lm_saes.activation.activation_store import ActivationStore
 from lm_saes.sae import SparseAutoEncoder
@@ -52,11 +53,7 @@ def train_sae(
     n_training_tokens = 0
     log_feature_sparsity = None
 
-    checkpoint_thresholds = []
-    if cfg.n_checkpoints > 0:
-        checkpoint_thresholds = list(
-            range(0, total_training_tokens, total_training_tokens // cfg.n_checkpoints)
-        )[1:]
+
     activation_store.initialize()
     if is_master():
         print(f"Activation Store Initialized.")
@@ -313,8 +310,8 @@ def train_sae(
 
         # Checkpoint if at checkpoint frequency
         if (
-            len(checkpoint_thresholds) > 0
-            and n_training_tokens >= checkpoint_thresholds[0]
+            len(cfg.checkpoint_thresholds) > 0
+            and n_training_tokens >= cfg.checkpoint_thresholds[0]
         ):
             # Save the model and optimizer state
             path = os.path.join(
@@ -325,7 +322,7 @@
             if not cfg.sae.sparsity_include_decoder_norm:
                 sae.set_decoder_norm_to_fixed_norm(1)
             sae.save_pretrained(path)
-            checkpoint_thresholds.pop(0)
+            cfg.checkpoint_thresholds.pop(0)
 
         n_training_steps += 1
 
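
As a rough illustration of the two schedules introduced above, here is a standalone sketch of the threshold computation. The helper name `checkpoint_thresholds` and the example numbers are hypothetical and not part of the library; the branches mirror the logic added to `LanguageModelSAETrainingConfig.__post_init__`. 'linear' spaces checkpoints evenly in token count, while 'log' places checkpoint i near step `total_training_steps ** (i / n_checkpoints)`, so saves cluster early in training and spread out geometrically.

```python
import math


def checkpoint_thresholds(
    total_training_tokens: int,
    effective_batch_size: int,
    n_checkpoints: int,
    mode: str = "log",
) -> list[int]:
    """Token counts at which a checkpoint would be saved (illustrative only)."""
    total_training_steps = total_training_tokens // effective_batch_size
    if mode == "linear":
        # Evenly spaced in token count, dropping the threshold at token 0.
        return list(
            range(0, total_training_tokens, total_training_tokens // n_checkpoints)
        )[1:]
    if mode == "log":
        # Geometrically spaced in step count: checkpoint i lands near step
        # total_training_steps ** (i / n_checkpoints), converted back to tokens.
        return [
            math.ceil(2 ** (i / n_checkpoints * math.log2(total_training_steps)))
            * effective_batch_size
            for i in range(1, n_checkpoints)
        ]
    raise ValueError(f"Unknown checkpoint save mode: {mode}")


if __name__ == "__main__":
    # Hypothetical run: 819.2M tokens, 4096-token effective batches, 10 checkpoints.
    print(checkpoint_thresholds(819_200_000, 4_096, 10, "log"))
    print(checkpoint_thresholds(819_200_000, 4_096, 10, "linear"))
```

With these example numbers, the 'linear' thresholds fall every ~82M tokens, whereas the 'log' thresholds start within the first few thousand steps and grow geometrically toward the end of training.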