diff --git a/src/lm_saes/config.py b/src/lm_saes/config.py index 808719e9..042a57fd 100644 --- a/src/lm_saes/config.py +++ b/src/lm_saes/config.py @@ -184,7 +184,7 @@ class SAEConfig(BaseModelConfig): init_encoder_with_decoder_transpose: bool = True l1_coefficient: float = 0.00008 - l1_coefficient_warmup_steps: int = 0 + l1_coefficient_warmup_steps: int | float = 0.1 lp: int = 1 use_ghost_grads: bool = False @@ -285,7 +285,7 @@ class LanguageModelSAETrainingConfig(LanguageModelSAERunnerConfig): "constantwithwarmup" # constant, constantwithwarmup, linearwarmupdecay, cosineannealing, cosineannealingwarmup, exponentialwarmup ) lr_end: Optional[float] = 1 / 32 - lr_warm_up_steps: int = 5000 + lr_warm_up_steps: int | float = 0.1 lr_cool_down_steps: int = 10000 train_batch_size: int = 4096 clip_grad_norm: float = 0.0 @@ -330,6 +330,15 @@ def __post_init__(self): total_training_steps = self.total_training_tokens // self.effective_batch_size print_once(f"Total training steps: {total_training_steps}") + if self.lr_scheduler_name == "constantwithwarmup" and isinstance(self.lr_warm_up_steps, float): + assert 0 <= self.lr_warm_up_steps <= 1.0 + self.lr_warm_up_steps = int(self.lr_warm_up_steps * total_training_steps) + print_once(f"Learning rate warm up steps: {self.lr_warm_up_steps}") + if isinstance(self.sae.l1_coefficient_warmup_steps, float): + assert 0 <= self.sae.l1_coefficient_warmup_steps <= 1.0 + self.sae.l1_coefficient_warmup_steps = int(self.sae.l1_coefficient_warmup_steps * total_training_steps) + print_once(f"L1 coefficient warm up steps: {self.sae.l1_coefficient_warmup_steps}") + @dataclass(kw_only=True) class LanguageModelSAEPruningConfig(LanguageModelSAERunnerConfig): """