
Commit

fix: minor type issues
dest1n1s committed Nov 15, 2024
1 parent cb35d35 commit 2cdb903
Showing 2 changed files with 6 additions and 9 deletions.
10 changes: 6 additions & 4 deletions src/lm_saes/optim.py
@@ -25,16 +25,18 @@ def get_scheduler(scheduler_name: Optional[str], optimizer: optim.Optimizer, **k
         training_steps, num_cycles, lr_end.
     """
 
-    def get_smoothing_lambda(training_steps, gamma: float, cool_down_steps: int, lr_end: float):
+    def get_smoothing_lambda(
+        training_steps: int, warm_up_steps: int, gamma: float, cool_down_steps: int, lr_end: float
+    ):
         smooth_steps = gamma * warm_up_steps
 
-        def lr_lambda(steps):
+        def lr_lambda(steps: int):
             if steps < smooth_steps:
                 return 2 * (steps + 1) / (warm_up_steps * (1 + gamma))
             elif steps < warm_up_steps:
                 return 1 - ((steps / warm_up_steps - 1) ** 2) / (1 - gamma**2)
             elif steps < cool_down_steps:
-                return 1
+                return 1.0
             else:
                 progress = (steps - cool_down_steps) / (training_steps - cool_down_steps)
                 return lr_end + 0.5 * (1 - lr_end) * (1 + math.cos(math.pi * progress))
@@ -93,7 +95,7 @@ def lr_lambda(steps: int):
         assert training_steps is not None, "training_steps must be provided"
         cool_down_steps = training_steps - int(1.5 * warm_up_steps)
         assert training_steps is not None, "training_steps must be provided"
-        lr_lambda = get_smoothing_lambda(training_steps, 0.5, cool_down_steps, 0.0)
+        lr_lambda = get_smoothing_lambda(training_steps, warm_up_steps, 0.5, cool_down_steps, 0.0)
         return lr_scheduler.LambdaLR(optimizer, lr_lambda)
     elif scheduler_name.lower() == "linearwarmupdecay":
         warm_up_steps = kwargs.get("warm_up_steps", 0)
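
For context, a minimal, self-contained sketch of how the corrected signature is meant to be driven: warm_up_steps is now passed explicitly instead of being picked up from the enclosing scope. The toy optimizer, step counts, and hyperparameter values below are illustrative assumptions, not code from the repository.

# Illustrative sketch only: a standalone copy of the fixed smoothing schedule,
# wired into torch.optim.lr_scheduler.LambdaLR. The linear model and the
# hyperparameter values are assumptions for demonstration.
import math

import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler


def get_smoothing_lambda(training_steps: int, warm_up_steps: int, gamma: float, cool_down_steps: int, lr_end: float):
    smooth_steps = gamma * warm_up_steps

    def lr_lambda(steps: int):
        if steps < smooth_steps:
            # shallow linear ramp over the first gamma * warm_up_steps steps
            return 2 * (steps + 1) / (warm_up_steps * (1 + gamma))
        elif steps < warm_up_steps:
            # quadratic piece that meets 1.0 smoothly at the end of warm-up
            return 1 - ((steps / warm_up_steps - 1) ** 2) / (1 - gamma**2)
        elif steps < cool_down_steps:
            return 1.0  # hold the base learning rate
        else:
            # cosine decay from 1.0 down to lr_end over the cool-down phase
            progress = (steps - cool_down_steps) / (training_steps - cool_down_steps)
            return lr_end + 0.5 * (1 - lr_end) * (1 + math.cos(math.pi * progress))

    return lr_lambda


training_steps, warm_up_steps = 1_000, 100
cool_down_steps = training_steps - int(1.5 * warm_up_steps)  # mirrors the call site above

optimizer = optim.Adam(nn.Linear(8, 8).parameters(), lr=3e-4)
scheduler = lr_scheduler.LambdaLR(
    optimizer,
    get_smoothing_lambda(training_steps, warm_up_steps, 0.5, cool_down_steps, 0.0),
)

for _ in range(training_steps):
    optimizer.step()
    scheduler.step()
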
5 changes: 0 additions & 5 deletions src/lm_saes/sae_training.py
@@ -59,7 +59,6 @@ def train_sae(
         if cfg.sae.use_glu_encoder:
             plan["encoder_glu"] = ColwiseParallel(output_layouts=Replicate())
         sae = parallelize_module(sae, device_mesh=sae.device_mesh["tp"], parallelize_plan=plan)  # type: ignore
-        sae.parallelize_plan = plan  # type: ignore
         sae.tensor_paralleled = True
 
     elif cfg.sae.ddp_size > 1:
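
For readers unfamiliar with the surrounding code, here is a hedged sketch of the tensor-parallel pattern this hunk touches: a plan maps submodule names to parallel styles and is applied once via parallelize_module, after which nothing needs to hold a reference to it (which is why the redundant `sae.parallelize_plan = plan` assignment can go). The TinySAE module, the CPU/gloo setup, and the sizes below are assumptions for illustration, not code from the repository; PyTorch >= 2.4 and a torchrun launch (e.g. torchrun --nproc-per-node=2 tp_sketch.py) are assumed.

# Hedged sketch, not from the repository: apply a ColwiseParallel plan once,
# then discard it, mirroring the pattern in the diff above.
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module


class TinySAE(nn.Module):
    def __init__(self, d_model: int = 16, d_sae: int = 64):
        super().__init__()
        self.encoder = nn.Linear(d_model, d_sae)
        self.decoder = nn.Linear(d_sae, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decoder(torch.relu(self.encoder(x)))


if __name__ == "__main__":
    dist.init_process_group("gloo")
    mesh = init_device_mesh("cpu", (dist.get_world_size(),), mesh_dim_names=("tp",))

    torch.manual_seed(0)  # keep weights and inputs identical across ranks
    sae = TinySAE()

    # Column-shard the encoder but replicate its output, mirroring
    # ColwiseParallel(output_layouts=Replicate()) in the diff above.
    plan = {"encoder": ColwiseParallel(output_layouts=Replicate())}
    sae = parallelize_module(sae, device_mesh=mesh, parallelize_plan=plan)

    out = sae(torch.randn(8, 16))
    if dist.get_rank() == 0:
        print(out.shape)  # torch.Size([8, 16])
    dist.destroy_process_group()
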
@@ -152,13 +151,9 @@ def train_sae(
         if cfg.wandb.log_to_wandb and (is_master()):
             feature_sparsity = act_freq_scores / n_frac_active_tokens
             log_feature_sparsity = torch.log10(feature_sparsity + 1e-10)
-            # wandb_histogram = wandb.Histogram(
-            #     log_feature_sparsity.detach().cpu().float().numpy()
-            # )
             wandb.log(
                 {
                     "metrics/mean_log10_feature_sparsity": log_feature_sparsity.mean().item(),
-                    # "plots/feature_density_line_chart": wandb_histogram,
                     "sparsity/below_1e-5": (feature_sparsity < 1e-5).sum().item(),
                     "sparsity/below_1e-6": (feature_sparsity < 1e-6).sum().item(),
                 },
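
The logging that remains after dropping the commented-out histogram reduces to a few scalar sparsity statistics. A minimal sketch of that computation follows, with a random act_freq_scores tensor standing in for the real activation counters (an assumption for illustration):

# Illustrative sketch: recompute the sparsity scalars that the diff above keeps.
# act_freq_scores and n_frac_active_tokens are placeholders, not real training state.
import torch

n_features = 4096
n_frac_active_tokens = 10_000
act_freq_scores = torch.randint(0, 50, (n_features,)).float()  # per-feature activation counts

feature_sparsity = act_freq_scores / n_frac_active_tokens
log_feature_sparsity = torch.log10(feature_sparsity + 1e-10)

metrics = {
    "metrics/mean_log10_feature_sparsity": log_feature_sparsity.mean().item(),
    "sparsity/below_1e-5": (feature_sparsity < 1e-5).sum().item(),
    "sparsity/below_1e-6": (feature_sparsity < 1e-6).sum().item(),
}
print(metrics)  # in train_sae these values go to wandb.log(...) on the master rank
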
