From 2cdb9032b53d37e042d03fb9117716c389ea0ac1 Mon Sep 17 00:00:00 2001
From: Dest1n1
Date: Fri, 15 Nov 2024 19:24:16 +0800
Subject: [PATCH] fix: minor type issues

---
 src/lm_saes/optim.py        | 10 ++++++----
 src/lm_saes/sae_training.py |  5 -----
 2 files changed, 6 insertions(+), 9 deletions(-)

diff --git a/src/lm_saes/optim.py b/src/lm_saes/optim.py
index 50a9196..387a43d 100644
--- a/src/lm_saes/optim.py
+++ b/src/lm_saes/optim.py
@@ -25,16 +25,18 @@ def get_scheduler(scheduler_name: Optional[str], optimizer: optim.Optimizer, **k
         training_steps, num_cycles, lr_end.
     """
 
-    def get_smoothing_lambda(training_steps, gamma: float, cool_down_steps: int, lr_end: float):
+    def get_smoothing_lambda(
+        training_steps: int, warm_up_steps: int, gamma: float, cool_down_steps: int, lr_end: float
+    ):
         smooth_steps = gamma * warm_up_steps
 
-        def lr_lambda(steps):
+        def lr_lambda(steps: int):
             if steps < smooth_steps:
                 return 2 * (steps + 1) / (warm_up_steps * (1 + gamma))
             elif steps < warm_up_steps:
                 return 1 - ((steps / warm_up_steps - 1) ** 2) / (1 - gamma**2)
             elif steps < cool_down_steps:
-                return 1
+                return 1.0
             else:
                 progress = (steps - cool_down_steps) / (training_steps - cool_down_steps)
                 return lr_end + 0.5 * (1 - lr_end) * (1 + math.cos(math.pi * progress))
@@ -93,7 +95,7 @@ def lr_lambda(steps: int):
         assert training_steps is not None, "training_steps must be provided"
         cool_down_steps = training_steps - int(1.5 * warm_up_steps)
         assert training_steps is not None, "training_steps must be provided"
-        lr_lambda = get_smoothing_lambda(training_steps, 0.5, cool_down_steps, 0.0)
+        lr_lambda = get_smoothing_lambda(training_steps, warm_up_steps, 0.5, cool_down_steps, 0.0)
         return lr_scheduler.LambdaLR(optimizer, lr_lambda)
     elif scheduler_name.lower() == "linearwarmupdecay":
         warm_up_steps = kwargs.get("warm_up_steps", 0)
diff --git a/src/lm_saes/sae_training.py b/src/lm_saes/sae_training.py
index e22f128..8633f9a 100644
--- a/src/lm_saes/sae_training.py
+++ b/src/lm_saes/sae_training.py
@@ -59,7 +59,6 @@ def train_sae(
         if cfg.sae.use_glu_encoder:
             plan["encoder_glu"] = ColwiseParallel(output_layouts=Replicate())
         sae = parallelize_module(sae, device_mesh=sae.device_mesh["tp"], parallelize_plan=plan)  # type: ignore
-        sae.parallelize_plan = plan  # type: ignore
         sae.tensor_paralleled = True
 
     elif cfg.sae.ddp_size > 1:
@@ -152,13 +151,9 @@ def train_sae(
         if cfg.wandb.log_to_wandb and (is_master()):
             feature_sparsity = act_freq_scores / n_frac_active_tokens
             log_feature_sparsity = torch.log10(feature_sparsity + 1e-10)
-            # wandb_histogram = wandb.Histogram(
-            #     log_feature_sparsity.detach().cpu().float().numpy()
-            # )
             wandb.log(
                 {
                     "metrics/mean_log10_feature_sparsity": log_feature_sparsity.mean().item(),
-                    # "plots/feature_density_line_chart": wandb_histogram,
                     "sparsity/below_1e-5": (feature_sparsity < 1e-5).sum().item(),
                     "sparsity/below_1e-6": (feature_sparsity < 1e-6).sum().item(),
                 },